Merge remote-tracking branch 'upstream/master' into constant_from_identity

# Conflicts:
#	tensorflow/python/framework/tensor_util.py
#	tensorflow/python/kernel_tests/confusion_matrix_test.py
Author: ngc92, 2020-04-22 20:02:22 +03:00
Commit: 4b265f2022
4567 changed files with 220359 additions and 79094 deletions

@@ -19,10 +19,10 @@
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# C++1z: Build with C++17 options
# c++1z: Build with C++17 options
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# arch_native_linux: Build with instruction sets available to the host machine on linux
# native_arch_linux: Build with instruction sets available to the host machine on linux
# avx_win: Build with avx instruction set on windows
# avx2_win: Build with avx2 instruction set on windows
#
@@ -73,6 +73,10 @@
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
#
# Embedded Linux options (experimental and only tested with TFLite build yet)
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
@@ -432,6 +436,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
# TFLite build configs for generic embedded Linux
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:elinux_aarch64 --config=elinux
build:elinux_aarch64 --cpu=aarch64
build:elinux_armhf --config=elinux
build:elinux_armhf --cpu=armhf
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line

@@ -1 +1 @@
2.0.0
3.0.0

@@ -10,32 +10,30 @@ labels: 'type:bug'
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

@@ -38,6 +38,9 @@ state what is wrong:
- Producing correct results, but the model is slower than expected (model generated from old converter)
**RNN conversion support**
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@@ -11,32 +11,29 @@ As per our
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

.github/stale.yml (vendored, new file)
@@ -0,0 +1,39 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you.
# Comment to post when removing the stale label. Set to `false` to disable
unmarkComment: false
closeComment: >
Closing as stale. Please reopen if you'd like to work on this further.
limitPerRun: 30
# Limit to only `issues` or `pulls`
only: issues

.gitignore (vendored)
@@ -38,6 +38,7 @@ gradleBuild
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/coreml/**/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD

@@ -2,6 +2,10 @@
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
**`Documentation`** |
------------------- |
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |
@@ -135,6 +139,7 @@ Build Type | Status
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow Blog](https://blog.tensorflow.org)

@@ -18,6 +18,8 @@ load("//tensorflow:workspace.bzl", "tf_repositories")
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_repositories()
register_execution_platforms("@local_execution_config_platform//:platform")
register_toolchains("@local_execution_config_python//:py_toolchain")
register_toolchains("@local_config_python//:py_toolchain")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")

@@ -2,58 +2,42 @@ package(default_visibility = ["//visibility:public"])
filegroup(
name = "gcc",
srcs = [
"bin/arm-rpi-linux-gnueabihf-gcc",
],
srcs = glob(["bin/*-gcc"]),
)
filegroup(
name = "ar",
srcs = [
"bin/arm-rpi-linux-gnueabihf-ar",
],
srcs = glob(["bin/*-ar"]),
)
filegroup(
name = "ld",
srcs = [
"bin/arm-rpi-linux-gnueabihf-ld",
],
srcs = glob(["bin/*-ld"]),
)
filegroup(
name = "nm",
srcs = [
"bin/arm-rpi-linux-gnueabihf-nm",
],
srcs = glob(["bin/*-nm"]),
)
filegroup(
name = "objcopy",
srcs = [
"bin/arm-rpi-linux-gnueabihf-objcopy",
],
srcs = glob(["bin/*-objcopy"]),
)
filegroup(
name = "objdump",
srcs = [
"bin/arm-rpi-linux-gnueabihf-objdump",
],
srcs = glob(["bin/*-objdump"]),
)
filegroup(
name = "strip",
srcs = [
"bin/arm-rpi-linux-gnueabihf-strip",
],
srcs = glob(["bin/*-strip"]),
)
filegroup(
name = "as",
srcs = [
"bin/arm-rpi-linux-gnueabihf-as",
],
srcs = glob(["bin/*-as"]),
)
filegroup(
@@ -66,6 +50,16 @@ filegroup(
]),
)
filegroup(
name = "aarch64_compiler_pieces",
srcs = glob([
"aarch64-none-linux-gnu/**",
"libexec/**",
"lib/gcc/aarch64-none-linux-gnu/**",
"include/**",
]),
)
filegroup(
name = "compiler_components",
srcs = [

configure (vendored)
@@ -4,7 +4,7 @@ set -e
set -o pipefail
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$(which python || which python3 || true)
PYTHON_BIN_PATH=$(which python3 || which python || true)
fi
# Set all env variables

@@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''

@@ -214,6 +214,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_armhf",
values = {"cpu": "armhf"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
@@ -639,7 +645,7 @@ tf_cc_shared_object(
"//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",
"//tensorflow/core:gpu_runtime_impl",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler:profiler_impl",
@@ -703,8 +709,8 @@ tf_cc_shared_object(
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:distributed_tensorflow_dependencies",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
)

@@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()

@@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()

@@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_file_statistics.h",
@@ -58,6 +59,7 @@ filegroup(
srcs = [
"c_api_internal.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@@ -116,6 +118,12 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "c_api_macros",
hdrs = ["c_api_macros.h"],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
hdrs = [
@@ -238,6 +246,16 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
@@ -264,6 +282,7 @@ cc_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
@@ -271,6 +290,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
@@ -281,16 +301,18 @@ tf_cuda_library(
"tf_tensor.h",
"tf_tensor_internal.h",
],
visibility = ["//tensorflow/c:__subpackages__"],
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
@@ -318,6 +340,8 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",

@@ -56,16 +56,16 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

@@ -24,17 +24,18 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -686,8 +687,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
}
namespace {
@@ -827,8 +827,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());

@@ -218,7 +218,7 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
float five = 5.0;
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five);
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CheckOutputShapes(fill_op,

@@ -27,8 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::errors::InvalidArgument;

@@ -14,17 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {

@@ -40,8 +40,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
@@ -186,10 +186,6 @@ struct TF_Server {
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);

@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_MACROS_H_
#define TENSORFLOW_C_C_API_MACROS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#endif // TENSORFLOW_C_C_API_MACROS_H_
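
As a usage note, here is a minimal sketch of how a header might consume this macro; `TF_MyExperimentalFeature` is a hypothetical name, not an actual TensorFlow symbol:

```c
// Hypothetical consumer of c_api_macros.h. When libtensorflow itself is
// built, TF_COMPILE_LIBRARY is defined and the symbol is exported
// (dllexport on Windows, default visibility elsewhere); client code that
// merely includes the header sees dllimport on Windows.
#include "tensorflow/c/c_api_macros.h"

#ifdef __cplusplus
extern "C" {
#endif

TF_CAPI_EXPORT extern void TF_MyExperimentalFeature(int flag);

#ifdef __cplusplus
}
#endif
```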

@@ -43,10 +43,10 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/public/session_options.h"
using tensorflow::GraphDef;

@@ -18,9 +18,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

@@ -29,11 +29,6 @@ tf_cuda_library(
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.cc",
"context_interface.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@@ -43,29 +38,39 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
":operation_interface",
":tensor_handle_interface",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
":tfe_executor_internal",
":tfe_monitoring_internal",
":tfe_op_attrs_internal",
":tfe_op_internal",
":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@@ -104,6 +109,14 @@ filegroup(
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
"tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",
"tfe_op_internal.h",
"tfe_tensor_debug_info_internal.h",
"tfe_tensorhandle_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@@ -111,38 +124,167 @@ filegroup(
],
)
tf_cuda_library(
cc_library(
name = "c_api_internal",
srcs = [
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
"context_interface.h",
"operation_interface.h",
"tensor_handle_interface.h",
"c_api_internal.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal",
],
deps = [
":c_api",
"//tensorflow/c:c_api",
":tfe_cancellation_manager_internal",
":tfe_context_internal",
":tfe_executor_internal",
":tfe_monitoring_internal",
":tfe_op_attrs_internal",
":tfe_op_internal",
":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "operation_interface",
hdrs = ["operation_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "context_interface",
hdrs = ["context_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tfe_context_internal",
hdrs = ["tfe_context_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":context_interface",
],
)
cc_library(
name = "tfe_cancellation_manager_internal",
hdrs = ["tfe_cancellation_manager_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
],
)
cc_library(
name = "tfe_executor_internal",
hdrs = ["tfe_executor_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/container:fixed_array",
],
)
cc_library(
name = "tfe_monitoring_internal",
hdrs = ["tfe_monitoring_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "tfe_op_attrs_internal",
hdrs = ["tfe_op_attrs_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tfe_context_internal",
":tfe_op_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
],
)
cc_library(
name = "tfe_op_internal",
hdrs = ["tfe_op_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
],
)
cc_library(
name = "tfe_tensor_debug_info_internal",
hdrs = ["tfe_tensor_debug_info_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:lib",
],
)
cc_library(
name = "tfe_tensorhandle_internal",
hdrs = ["tfe_tensorhandle_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
],
)
@@ -157,6 +299,7 @@ tf_cuda_library(
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -184,9 +327,12 @@ tf_cuda_cc_test(
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/strings",
],
)
@@ -198,11 +344,6 @@ tf_cuda_cc_test(
"c_api_remote_test.cc",
],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"multi_gpu",
"no_oss",
],
deps = [
":c_api",
":c_api_experimental",
@@ -213,6 +354,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
@@ -243,6 +385,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
@@ -265,7 +408,6 @@ tf_cuda_library(
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
@@ -329,6 +471,22 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
hdrs = ["custom_device_testutil.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@@ -339,6 +497,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
@@ -388,6 +547,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@dlpack",
],
alwayslink = 1,
@@ -405,6 +565,7 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"*c_api_tfrt*",
"*test*",
"*dlpack*",
],

@@ -22,28 +22,30 @@ limitations under the License.
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -53,8 +55,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
@@ -73,18 +75,17 @@ limitations under the License.
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
@@ -94,15 +95,6 @@ using tensorflow::string;
namespace {
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
return false;
}
tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
@@ -690,6 +682,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return new TFE_Context{new tfrt::ContextInterface()};
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
#endif
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -701,16 +702,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator()))};
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@@ -721,24 +720,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator()))};
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
void TFE_DeleteContext(TFE_Context* ctx) {
if (ctx == nullptr) {
return;
}
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->Unref();
ctx->context->Release();
delete ctx;
}
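
For orientation, a minimal sketch of the context lifecycle these functions now implement, using only public C API calls; error handling is abbreviated and the function name is illustrative:

```c
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

void context_lifecycle_sketch(void) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // After this change, TFE_Context holds a raw context pointer and
  // TFE_DeleteContext releases it via ctx->context->Release().
  TFE_Context* ctx = TFE_NewContext(opts, status);
  if (TF_GetCode(status) == TF_OK) {
    /* ... create handles and run ops ... */
    TFE_DeleteContext(ctx);
  }
  TFE_DeleteContextOptions(opts);
  TF_DeleteStatus(status);
}
```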
@@ -921,253 +920,102 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
if (h->handle) {
h->handle->Release();
}
delete h;
}
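
The caller-side counterpart, sketched under the same assumptions (`make_scalar_handle` is an illustrative helper, not a TensorFlow function):

```c
#include <string.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

TFE_TensorHandle* make_scalar_handle(TF_Status* status) {
  // Allocate a host TF_Tensor holding a single float scalar.
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/NULL, /*num_dims=*/0,
                                   sizeof(float));
  const float value = 42.0f;
  memcpy(TF_TensorData(t), &value, sizeof(float));
  // Wrap it in an eager handle; the handle keeps its own reference, so the
  // TF_Tensor can be dropped immediately. TFE_DeleteTensorHandle later
  // releases the underlying handle, as shown above.
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
  TF_DeleteTensor(t);
  return h;  // NULL on failure; check `status`.
}
```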
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}
return true;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return h->handle->DataType();
}
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
return static_cast<TF_DataType>(h->handle->DataType());
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return h->handle->NumDims(&status->status);
}
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}
int result;
*status = handle_->NumDims(&result);
return result;
int num_dims = -1;
status->status = h->handle->NumDims(&num_dims);
return num_dims;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return h->handle->NumElements(&status->status);
}
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
*status = handle_->NumElements(&result);
return result;
int64 num_elements = -1;
status->status = h->handle->NumElements(&num_elements);
return num_elements;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return h->handle->Dim(dim_index, &status->status);
}
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
*status = handle_->Dim(dim_index, &result);
return result;
int64 dim = -1;
status->status = h->handle->Dim(dim_index, &dim);
return dim;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->DeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<CustomDevice*>(handle_->device())->name().c_str();
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->BackingDeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<tensorflow::CustomDevice*>(handle_->device())
->name()
.c_str();
} else {
tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return new TFE_TensorHandle{
std::unique_ptr<tensorflow::AbstractTensorHandleInterface>(
h->handle->Copy())};
}
tensorflow::AbstractTensorHandleInterface*
tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}
void tensorflow::TensorHandleInterface::EnableImplicitMirroring() {
handle_->EnableImplicitMirroring();
return new TFE_TensorHandle{h->handle->Copy()};
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->Resolve(&status->status);
}
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
if (t == nullptr) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
tensorflow::TensorHandle* copy;
*status = custom_device->CopyTensorFromDevice(
handle_, "/job:localhost/task:0/replica:0/device:CPU:0", &copy);
if (status->ok()) {
return TensorHandleInterface(copy).Resolve(status);
} else {
return nullptr;
}
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
if (handle_->HasLocalMirror(nullptr)) {
*status = handle_->TensorFromDevice(nullptr, &src);
} else {
*status = handle_->Tensor(&src);
}
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
if (handle_->ImplicitMirroring()) {
*status = handle_->AddEmptyLocalMirror(nullptr);
if (!status->ok()) return nullptr;
Tensor mirror = tensor;
*status = handle_->SetTensor(std::move(mirror), nullptr);
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
return new TF_Tensor{t};
}
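
A short caller-side sketch of the resolve path (`resolve_to_host` is an illustrative name):

```c
// Materializes the handle's contents as a host TF_Tensor. With this change
// the C API simply forwards to handle->Resolve() and wraps the returned
// AbstractTensorInterface in a TF_Tensor.
TF_Tensor* resolve_to_host(TFE_TensorHandle* h, TF_Status* status) {
  TF_Tensor* t = TFE_TensorHandleResolve(h, status);
  if (TF_GetCode(status) != TF_OK) return NULL;
  // TF_TensorData(t) and TF_TensorByteSize(t) now describe host memory;
  // the caller owns `t` and frees it with TF_DeleteTensor.
  return t;
}
```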
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
@@ -1233,15 +1081,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
if (custom_device == nullptr) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device,
device, context))};
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context)};
} else {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context))};
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context)};
}
}
@@ -1250,9 +1094,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return 0;
}
tensorflow::TensorHandle* handle =
@@ -1281,7 +1124,17 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
return new_op.release();
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_DeleteOp(TFE_Op* op) {
if (op == nullptr) {
return;
}
if (op->operation) {
op->operation->Release();
}
delete op;
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation->SetDeviceName(device_name);
@@ -1309,12 +1162,13 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(num_inputs);
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i].reset(inputs[i]->handle->Copy());
handles[i] = inputs[i]->handle;
}
status->status = op->operation->AddInputList(handles);
status->status =
op->operation->AddInputList({handles.data(), handles.size()});
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@@ -1378,7 +1232,8 @@ void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
auto s = op->operation->SetAttrType(attr_name, value);
auto s = op->operation->SetAttrType(attr_name,
static_cast<tensorflow::DataType>(value));
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@@ -1407,7 +1262,10 @@ void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
status->status = op->operation->SetAttrTensor(attr_name, tensor);
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
tensorflow::TensorInterface interface(t);
status->status = op->operation->SetAttrTensor(attr_name, &interface);
}
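
For context, the typical caller of this conversion is a Const op; a minimal sketch assuming `ctx`, `tensor`, and `status` are already set up (`make_const` is an illustrative helper):

```c
// Builds a Const op whose "value" attribute travels through the
// TF_Tensor -> tensorflow::Tensor -> TensorInterface conversion above.
TFE_Op* make_const(TFE_Context* ctx, TF_Tensor* tensor, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "Const", status);
  if (TF_GetCode(status) != TF_OK) return NULL;
  TFE_OpSetAttrType(op, "dtype", TF_TensorType(tensor));
  TFE_OpSetAttrTensor(op, "value", tensor, status);
  return op;
}
```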
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
@@ -1438,7 +1296,9 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
auto s = op->operation->SetAttrTypeList(
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@@ -1461,7 +1321,13 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
num_values);
for (int i = 0; i < num_values; ++i) {
values[i] = value[i]->operation;
}
auto s = op->operation->SetAttrFunctionList(attr_name,
{values.data(), values.size()});
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@@ -1481,8 +1347,7 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
"Got a null or uninitialized `op` argument");
return;
}
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
operation->MutableAttrs()->Set(attr_name, attr_value);
}
@@ -1504,14 +1369,14 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
*num_retvals);
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
retvals[i] = new TFE_TensorHandle{handles[i]};
}
}
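
A minimal caller-side sketch of TFE_Execute under the new raw-pointer convention (`eager_add` is an illustrative helper; `a` and `b` are assumed to be existing scalar handles):

```c
// Runs Add(a, b) eagerly. Each retvals[i] is a TFE_TensorHandle that now
// wraps the raw AbstractTensorHandleInterface* produced by Execute above.
TFE_TensorHandle* eager_add(TFE_Context* ctx, TFE_TensorHandle* a,
                            TFE_TensorHandle* b, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "Add", status);
  if (TF_GetCode(status) != TF_OK) return NULL;
  TFE_OpAddInput(op, a, status);
  TFE_OpAddInput(op, b, status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);
  TFE_DeleteOp(op);
  return TF_GetCode(status) == TF_OK ? retvals[0] : NULL;
}
```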
@@ -1519,48 +1384,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::TensorHandleFromInterface(h->handle), &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
}
return nullptr;
}
// Handle tensor handles currently in custom devices
const char* handle_device_name = h->handle->DeviceName(&status->status);
if (!status->status.ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::TensorHandleFromInterface(h->handle), context,
&context->Executor(), device, false, &handle);
auto* result = ctx->context->CopyTensorHandleToDevice(h->handle, device_name,
&status->status);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
return new TFE_TensorHandle{result};
}
return nullptr;
}
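
And the matching caller-side sketch for the simplified copy path; the device string is illustrative and assumes a GPU is available:

```c
// Copies `h` onto GPU:0. Device lookup and custom-device handling are now
// delegated to ctx->context->CopyTensorHandleToDevice, as shown above.
TFE_TensorHandle* copy_to_gpu(TFE_Context* ctx, TFE_TensorHandle* h,
                              TF_Status* status) {
  return TFE_TensorHandleCopyToDevice(
      h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
}
```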
@@ -1615,9 +1447,7 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(t))};
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
@@ -1657,18 +1487,16 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
}
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}
@@ -1724,6 +1552,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
// require TFE_Op* and just convert it internally to a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.
TFE_OpSetAttrFunction(op, attr_name, func_op);
TFE_DeleteOp(func_op);
} break;
case tensorflow::AttrValue::kList:
TF_FALLTHROUGH_INTENDED;
@ -1756,15 +1585,15 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle tensor_handle{tensor};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
tensor_handle.handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
@ -1774,14 +1603,14 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle tensor_handle{tensor};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
tensor_handle.handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
@ -1792,9 +1621,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
op->Inputs()[i])});
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
@ -1805,12 +1632,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i]->Ref();
delete outputs[i];
TFE_DeleteTensorHandle(outputs[i]);
}
}
for (auto inp : inputs) {
delete inp;
TFE_DeleteTensorHandle(inp);
}
return status.status;
}
@ -1823,6 +1650,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
};
} // namespace
extern "C" {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status) {
@ -1833,3 +1662,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}
} // extern "C"
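// Registration sketch. The copy_tensor_to_device / copy_tensor_from_device /
// execute fields match the callbacks the wrapper class above invokes; the
// my_* symbols are hypothetical user-provided functions and state:
//
//   TFE_CustomDevice device;
//   device.copy_tensor_to_device = my_copy_to_device;
//   device.copy_tensor_from_device = my_copy_from_device;
//   device.execute = my_execute;
//   TFE_RegisterCustomDevice(ctx, device,
//                            "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                            my_device_state /* opaque device_info */, status);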


@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -54,36 +57,32 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
return h->handle->TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
const tensorflow::Tensor* tensor;
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = absl::get<Device*>(handle_->device());
auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
dynamic_cast<tensorflow::XlaDevice*>(device);
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
std::vector<int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
*status = tensorflow::Status::OK();
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@ -96,7 +95,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64-bit ints, doubles, and complex numbers (we don't support
// 64-bit complex numbers).
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -108,13 +107,13 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -139,15 +138,15 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
*status = tensorflow::Status::OK();
status->status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as the regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);


@ -21,8 +21,13 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
TEST(CApiDebug, ScalarCPU) {
TFE_TensorHandle* h = TestScalarTensorHandle(1.0f);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -30,12 +35,18 @@ TEST(CApiDebug, ScalarCPU) {
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CApiDebug, 2DCPU) {
TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -46,5 +57,6 @@ TEST(CApiDebug, 2DCPU) {
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}


@ -15,22 +15,26 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
@ -525,7 +529,11 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
status->status = op->operation->SetCancellationManager(cancellation_manager);
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
}
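// Calling pattern this enables, assuming `op` and `status` exist; the other
// cancellation-manager entry points live in the same experimental header:
//
//   TFE_CancellationManager* cm = TFE_NewCancellationManager();
//   TFE_OpSetCancellationManager(op, cm, status);
//   // ... from another thread, to abort in-flight execution:
//   TFE_CancellationManagerStartCancel(cm);
//   // ... once no pending op references it:
//   TFE_DeleteCancellationManager(cm);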
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -574,12 +582,6 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
};
}
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
TF_Status* status) {
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
@ -600,3 +602,33 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
};
status->status = tensorflow::Status::OK();
}
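// Retrieval sketch, assuming a function named "MatMulFunction" was added to
// `ctx` earlier:
//
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextGetFunctionDef(ctx, "MatMulFunction", buf, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // buf->data / buf->length hold a serialized FunctionDef proto.
//   }
//   TF_DeleteBuffer(buf);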
TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
const int64_t* dims, int num_dims,
TF_Status* status) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (ctx == nullptr || ctx->context == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
return nullptr;
}
return new TF_Tensor{t};
}
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) {
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
}


@ -392,12 +392,6 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
TFE_TensorHandle*, TF_Status*);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
@ -521,15 +515,34 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
TFE_CustomDevice device,
const char* device_name,
void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
// Allocate and return a new Tensor on the host.
//
// The caller must set the Tensor values by writing them to the pointer returned
// by TF_TensorData with length TF_TensorByteSize.
TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
TF_DataType dtype,
const int64_t* dims,
int num_dims,
TF_Status* status);
// Given a Tensor, wrap it with a TensorHandle.
//
// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
// The Tensor must come from the same context, e.g. one allocated with
// TFE_AllocateHostTensor above.
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
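// Intended flow for the two declarations above (it mirrors the updated test
// utilities later in this diff); `ctx` and `status` are assumed to exist:
//
//   int64_t dims[] = {2, 2};
//   TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, dims, 2, status);
//   float* data = static_cast<float*>(TF_TensorData(t));
//   for (int i = 0; i < 4; ++i) data[i] = 1.0f;  // 4 floats == TF_TensorByteSize(t)
//   TFE_TensorHandle* h = TFE_NewTensorHandleFromTensor(ctx, t, status);
//   TF_DeleteTensor(t);  // the handle keeps its own reference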
#ifdef __cplusplus
} /* end extern "C" */
#endif


@ -21,9 +21,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -378,7 +378,7 @@ void Executor_MatMul_CPU(bool async) {
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
int num_retvals = 2;
@ -423,7 +423,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float* m_float = static_cast<float*>(TF_TensorData(m_data));

View File

@ -15,42 +15,20 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export
// TODO(b/154564140): Move this to its own header. This requires splitting
// c_api_experimental.h
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
@ -64,181 +42,4 @@ struct TFE_ContextOptions {
bool use_tfrt = false;
};
struct TFE_Context {
std::unique_ptr<tensorflow::AbstractContextInterface> context;
};
struct TFE_TensorHandle {
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
: dev_dims(dims) {}
// Fully-padded, minor-to-major.
std::vector<tensorflow::int64> dev_dims;
};
struct TFE_Op {
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
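// These cells back the experimental monitoring C API; a minimal counter
// sketch using entry points from c_api_experimental.h:
//
//   TFE_MonitoringCounter0* counter = TFE_MonitoringNewCounter0(
//       "/tensorflow/example/counter", status, "An example counter.");
//   TFE_MonitoringCounterCell* cell = TFE_MonitoringGetCellCounter0(counter);
//   TFE_MonitoringCounterCellIncrementBy(cell, 1);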
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
explicit TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status);
} // namespace tensorflow
struct TFE_CancellationManager {
tensorflow::CancellationManager cancellation_manager;
};
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
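// Driving this wrapper from the C API, as the tests in this diff do:
//
//   TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/true);
//   TFE_ContextSetExecutorForThread(ctx, executor);
//   // ... enqueue ops with TFE_Execute ...
//   TFE_ExecutorWaitForAllPendingNodes(executor, status);
//   TFE_DeleteExecutor(executor);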
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_


@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
@ -74,8 +75,8 @@ void TestRemoteExecute(bool async) {
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
@ -128,7 +129,45 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -159,23 +198,45 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
// The input handle should not have a local mirror before execution runs.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -183,12 +244,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(1), remote_arg);
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
@ -218,6 +277,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@ -228,16 +290,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
TestRemoteExecuteSilentCopies(false, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true, true);
TestRemoteExecuteSilentCopies(true, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(true, true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
TestRemoteExecuteSilentCopies(false, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
TestRemoteExecuteSilentCopies(true, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(true, false, true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@ -268,8 +336,8 @@ void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
@ -332,7 +400,7 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
@ -415,7 +483,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);

View File

@ -19,16 +19,22 @@ limitations under the License.
#include <string>
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -47,7 +53,7 @@ void BM_InitOp(int iters) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_Op* matmul = MatMulOp(ctx, m, m);
@ -71,12 +77,19 @@ void BM_Execute(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(matmul, "MatMul", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@ -106,12 +119,16 @@ void BM_Execute_Identity(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* identity = IdentityOp(ctx, m);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(identity, "Identity", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(identity, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(identity, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@ -153,11 +170,16 @@ TEST(CAPI, Context) {
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(16, TF_TensorByteSize(t));
float data[4] = {0};
@ -168,6 +190,7 @@ TEST(CAPI, TensorHandle) {
EXPECT_EQ(4.0, data[3]);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
}
void TensorHandleCopyBetweenDevices(bool async) {
@ -179,7 +202,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -255,7 +278,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* kErrorDevice = "NoSuchDevice:0";
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
@ -296,7 +319,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -382,7 +405,7 @@ void TensorHandleSilentCopy(bool async,
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
@ -393,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
string cpu_device_name;
@ -408,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->GetInput(0), arg0);
EXPECT_EQ(op->GetInput(1), arg1);
// The CPU handle should have been copied and have a mirror on the GPU
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
@ -455,7 +476,7 @@ void SetAndGetOpDevices(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
// Disable the test if no GPU is present.
@ -487,40 +508,35 @@ TEST(CAPI, TensorHandleNullptr) {
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(t, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int dim = TFE_TensorHandleDim(h, 0, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(dim, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
@ -531,7 +547,7 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
@ -581,15 +597,16 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, tfrt);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
@ -597,12 +614,11 @@ void ExecuteAdd(bool async, bool forward_input) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
@ -619,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input) {
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);
@ -649,7 +654,6 @@ void ExecuteAdd(bool async, bool forward_input) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
@ -659,12 +663,42 @@ void ExecuteAdd(bool async, bool forward_input) {
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
TEST(CAPI, ExecuteAdd) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/false,
      /*tfrt=*/false);
}
TEST(CAPI, ExecuteAddAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/false,
      /*tfrt=*/false);
}
TEST(CAPI, ExecuteAddForward) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
TEST(CAPI, ExecuteAddForwardAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
#ifdef PLATFORM_GOOGLE
// TODO(b/153349425): Add Add-forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/false,
      /*tfrt=*/true);
}
#endif
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
@ -674,7 +708,7 @@ void Execute_MatMul_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
int num_retvals = 2;
@ -710,8 +744,8 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
status);
@ -782,8 +816,8 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -812,8 +846,8 @@ TEST(CAPI, Execute_Min_CPU) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -847,7 +881,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_OpSetXLACompilation(matmul, true);
@ -889,8 +923,8 @@ void Execute_Min_XLA_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_OpSetXLACompilation(minOp, true);
@ -930,7 +964,7 @@ void ExecuteWithTracing(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -1016,7 +1050,7 @@ void FunctionDefAndExecute(bool async) {
if (clear_cache) {
TFE_ContextClearCaches(ctx);
}
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
@ -1065,7 +1099,7 @@ void BM_ExecuteFunction(int iters, int async) {
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
@ -1280,11 +1314,15 @@ TEST(CAPI, StringAttributes) {
}
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
TFE_TensorHandle* h_shares_tensor =
TFE_TensorHandleCopySharingTensor(h, status.get());
@ -1302,13 +1340,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h);
TFE_DeleteTensorHandle(h_shares_tensor);
TFE_DeleteContext(ctx);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
->Attrs()
.FillAttrValueMap(&attr_values);
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
operation->Attrs().FillAttrValueMap(&attr_values);
return attr_values;
}
@ -1319,8 +1358,8 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(minOp, input, status);
@ -1356,9 +1395,9 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
@ -1396,9 +1435,9 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* condition = TestScalarTensorHandle(true);
TFE_TensorHandle* t1 = TestMatrixTensorHandle();
TFE_TensorHandle* t2 = TestAxisTensorHandle();
TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(assertOp, condition, status);
@ -1435,9 +1474,9 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
@ -1470,8 +1509,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -1518,8 +1557,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
@ -1563,8 +1602,8 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
copy_op->operation.get());
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(copy_op->operation);
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
@ -1599,26 +1638,26 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
name_and_attrs.attr().find("dtype")->second.type());
TF_DeleteBuffer(serialized_attr_values);
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_Op* var_op_2 = TFE_NewOp(ctx, "VarHandleOp", status);
string serialized_dtype;
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
&serialized_dtype));
TFE_OpSetAttrValueProto(
second_var_op, "dtype",
var_op_2, "dtype",
reinterpret_cast<const void*>(serialized_dtype.c_str()),
serialized_dtype.length(), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
second_var_op->operation.get());
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(var_op_2->operation);
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(second_var_op);
TFE_DeleteOp(var_op_2);
TFE_DeleteContext(ctx);
}


@ -16,115 +16,117 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::string;
TFE_TensorHandle* TestScalarTensorHandle(float value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
float data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(int value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
int data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(bool value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
bool data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {2, 2};
double data[] = {1.0, 2.0, 3.0, 4.0};
TF_Tensor* t = TF_AllocateTensor(
TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
  constexpr int64_t dims[] = {100, 100};
  constexpr int num_elements = dims[0] * dims[1];
  float data[num_elements];
  for (int i = 0; i < num_elements; ++i) {
    data[i] = 1.0f;
  }
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
  int64_t dims[] = {3, 2};
  double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
  TF_Status* status = TF_NewStatus();
  // Allocate as TF_DOUBLE so the tensor's byte size matches the double data.
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
  int64_t dims[] = {3, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

@@ -187,14 +189,14 @@ TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
return op;
}
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
  int64_t dims[] = {1};
  int data[] = {1};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}


@@ -19,28 +19,28 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"

// Return a tensor handle containing a float scalar.
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);

// Return a tensor handle containing an int scalar.
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);

// Return a tensor handle containing a bool scalar.
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);

// Return a tensor handle containing a 2x2 matrix of doubles.
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);

// Return a tensor handle containing a 2x2 matrix of floats.
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);

// Return a tensor handle containing a 100x100 matrix of floats.
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);

// Return a tensor handle containing a 3x2 matrix of doubles.
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);

// Return a tensor handle containing a 3x2 matrix of floats.
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);

// Return an add op computing `a` + `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

@@ -55,7 +55,7 @@ TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);

// Return a 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);

// Return an op taking the minimum of `input` along the `axis` dimension.
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
              TFE_TensorHandle* axis);
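// Usage sketch (illustrative, not part of this change): how a test might feed
// the ctx-taking helpers above, assuming a TFE_Context built via the eager
// C API.
//
//   TF_Status* s = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, s);
//   TFE_DeleteContextOptions(opts);
//   TFE_TensorHandle* two = TestScalarTensorHandle(ctx, 2.0f);
//   TFE_Op* add = AddOp(ctx, two, two);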


@@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
@@ -38,96 +38,159 @@ typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,

struct TF_ExecutionContext {
  // Needed to implement our own version of RTTI since dynamic_cast is not
  // supported in mobile builds.
  enum ExecutionContextKind { GraphContext, EagerContext };
  explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
  ExecutionContextKind getKind() const { return k; }

  virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                                TF_AbstractTensor* const* inputs,
                                TF_OutputList* o, TF_Status* s) = 0;
  virtual TF_AbstractOp* CreateOperation() = 0;
  virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
  virtual ~TF_ExecutionContext() {}

 private:
  const ExecutionContextKind k;
};

void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }

template <typename T, typename S>
T* dynamic_cast_helper(S source) {
  if (source->getKind() != T::kKind) {
    return nullptr;
  }
  return tensorflow::down_cast<T*>(source);
}
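// Usage sketch (illustrative): dynamic_cast_helper is the kind-tag
// replacement for dynamic_cast, e.g.
//   auto* eager_ctx = dynamic_cast_helper<TF_EagerContext>(ctx);
//   if (eager_ctx == nullptr) { /* ctx is not a TF_EagerContext */ }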
class TF_GraphContext;
class TF_EagerContext;
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
struct TF_AbstractTensor {
  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;

  ~TF_AbstractTensor() {
    if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
      TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
    } else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
      delete absl::get<TF_GraphTensor*>(t);
    }
  }
};
struct TF_AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum AbstractOpKind { GraphOp, EagerOp };
explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
AbstractOpKind getKind() const { return k; }
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) = 0;
virtual ~TF_AbstractOp() {}
private:
const AbstractOpKind k;
};
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return c->CreateOperation();
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }

class TF_GraphOp : public TF_AbstractOp {
 public:
  explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
  void SetOpType(const char* const op_type, TF_Status* s) override {
    if (op_) {
      TF_SetStatus(
          s, TF_FAILED_PRECONDITION,
          absl::StrCat("SetOpType called on already built op.").c_str());
      return;
    }
    if (op_name_ != nullptr) {
      op_.reset(TF_NewOperation(g_, op_type, op_name_));
      op_name_ = nullptr;
    } else {
      op_type_ = op_type;
    }
  }
void SetOpName(const char* const op_name, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
}
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~TF_GraphOp() override {}
static constexpr AbstractOpKind kKind = GraphOp;
private:
friend class TF_GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
const char* op_name_ = nullptr;
};
class TF_EagerOp : public TF_AbstractOp {
public:
explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~TF_EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = EagerOp;
private:
friend class TF_EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
bool IsEagerTensor(const TF_AbstractTensor* const t) {
  return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}

@@ -138,6 +201,221 @@ struct TF_OutputList {
  int expected_num_outputs = -1;
};
struct TF_AbstractFunction {
TF_Function* func = nullptr;
~TF_AbstractFunction() { TF_DeleteFunction(func); }
};
class TF_EagerContext : public TF_ExecutionContext {
public:
TF_EagerContext() : TF_ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_EagerOp(eager_ctx_);
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = new TF_AbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TFE_ContextAddFunction(eager_ctx_, func->func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = EagerContext;
private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
class TF_GraphContext : public TF_ExecutionContext {
public:
TF_GraphContext()
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_GraphOp(graph_.get());
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_GraphOp.");
return;
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = new TF_AbstractTensor;
TF_GraphTensor* graph_t = new TF_GraphTensor;
graph_t->ctx = this;
graph_t->output = {operation, i};
t->t = graph_t;
o->outputs.push_back(t);
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs,
TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
~TF_GraphContext() override {}
static constexpr ExecutionContextKind kKind = GraphContext;
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
struct TF_GraphContextOptions {};
struct TF_EagerContextOptions {
explicit TF_EagerContextOptions(TFE_ContextOptions* options)
: options(options) {}
TFE_ContextOptions* options; // Not owned.
};
struct TF_ExecutionContextOptions {
absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
~TF_ExecutionContextOptions() {
if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
delete absl::get<TF_GraphContextOptions*>(options);
} else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
delete absl::get<TF_EagerContextOptions*>(options);
}
}
};
TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_GraphContextOptions();
return options;
}
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
delete options;
}
TF_ExecutionContextOptions* TF_NewEagerContextOptions(
TFE_ContextOptions* tfe_options) {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_EagerContextOptions(tfe_options);
return options;
}
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
TF_Status* s) {
if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
auto* ctx = new TF_EagerContext();
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
s);
return ctx;
} else {
return new TF_GraphContext();
}
}
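// Usage sketch (illustrative): the options value selects the concrete
// backend constructed above, e.g.
//   TF_ExecutionContextOptions* opt = TF_NewGraphContextOptions();
//   TF_ExecutionContext* graph_ctx = TF_NewExecutionContext(opt, s);
// while TF_NewEagerContextOptions(tfe_opts) yields a TF_EagerContext.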
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
@@ -149,113 +427,74 @@ TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                            TF_Status* s) {
  op->SetOpType(op_type, s);
}

void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                            TF_Status* s) {
  op->SetOpName(op_name, s);
}

void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
                              TF_DataType value, TF_Status* s) {
  op->SetAttrType(attr_name, value, s);
}

void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                         TF_AbstractTensor* const* inputs, TF_OutputList* o,
                         TF_ExecutionContext* ctx, TF_Status* s) {
  ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
TF_AbstractFunction* func = new TF_AbstractFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
outputs, status);
return func;
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
ctx->RegisterFunction(func, s);
}
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
}


@@ -15,8 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
@@ -41,32 +41,19 @@ typedef struct TF_AbstractTensor TF_AbstractTensor;
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;

// `TF_ExecutionContextOptions` defines what type of `TF_ExecutionContext` is
// created. It can be used to pass context-specific params.
typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);

TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
                                            TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);

TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);

void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@@ -74,25 +61,9 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// `attr_name` must outlive `op`.
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
                              TF_DataType value, TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
@@ -104,6 +75,17 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
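// Usage sketch (illustrative, mirrors the tests): in eager mode the expected
// output count must be set before execution.
//   TF_OutputList* o = TF_NewOutputList();
//   TF_OutputListSetNumOutputs(o, 1, s);
//   TF_ExecuteOperation(op, 2, inputs, o, ctx, s);
//   TF_AbstractTensor* result = TF_OutputListGet(o, 0);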
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
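// Usage sketch (illustrative, taken from the graph test below): trace a
// function in a graph context, then register it with an eager context in
// order to call it.
//   TF_AbstractFunction* fn = TF_ExecutionContextToFunction(
//       graph_ctx, "double", 1, input_t, 1, output_t, s);
//   TF_ExecutionContextRegisterFunction(eager_ctx, fn, s);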
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
@@ -112,6 +94,23 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
// -----------------------------------------------------------------------------
// APIs specific to Eager and graph modes
// -----------------------------------------------------------------------------
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_AbstractTensorSetEagerTensor(
TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s); // `at` takes ownership of `t`.
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
#ifdef __cplusplus
} /* end extern "C" */
#endif


@@ -21,9 +21,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

@@ -33,26 +33,25 @@ namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  // Build an abstract input tensor.
  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
  TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
  TF_AbstractTensor* at = TF_NewAbstractTensor();
  TF_AbstractTensorSetEagerTensor(at, t, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(op, "Add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

@@ -69,7 +68,6 @@ TEST(UnifedCAPI, TestBasicEager) {
  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at);

  // Verify the results.
  ASSERT_EQ(1, TF_OutputListNumOutputs(o));

@@ -83,100 +81,98 @@
  TF_DeleteTensor(result_tensor);
  TF_DeleteAbstractTensor(result);
  TFE_DeleteTensorHandle(result_t);
  TF_DeleteOutputList(o);
  TF_DeleteExecutionContext(ctx);
  TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestBasicGraph) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
  TF_ExecutionContext* graph_ctx =
      TF_NewExecutionContext(options, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Add a placeholder to the graph.
  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_OutputList* placeholder_outputs = TF_NewOutputList();

  // Execute.
  TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
                      graph_ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
  TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);

  // Delete placeholder op.
  TF_DeleteAbstractOp(placeholder_op);

  // Build an abstract operation.
  auto* add_op = TF_NewAbstractOp(graph_ctx);
  TF_AbstractOpSetOpType(add_op, "Add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOpSetOpName(add_op, "my_add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
  TF_OutputList* add_outputs = TF_NewOutputList();

  // Execute.
  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(add_op);

  string fn_name = "double";
  TF_AbstractFunction* func = TF_ExecutionContextToFunction(
      graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_DeleteAbstractTensor(placeholder_t);
  TF_DeleteAbstractTensor(output_t);

  // Build eager context.
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContextOptions* eager_ctx_options =
      TF_NewEagerContextOptions(opts);
  TF_ExecutionContext* eager_execution_ctx =
      TF_NewExecutionContext(eager_ctx_options, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build the abstract op to run the function.
  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract input tensor.
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(eager_execution_ctx);
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
  TF_AbstractTensor* input_t = TF_NewAbstractTensor();
  TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
                      status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
  TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
  TFE_TensorHandle* final =
      TF_AbstractTensorGetEagerTensor(final_result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

@@ -185,19 +181,201 @@ TEST(UnifedCAPI, TestBasicGraph) {
  float* f_value = static_cast<float*>(TF_TensorData(f_t));
  ASSERT_EQ(*f_value, 4.0);

  TF_DeleteOutputList(add_outputs);
  TF_DeleteOutputList(placeholder_outputs);
  TF_DeleteAbstractOp(fn_op);
  TF_DeleteAbstractTensor(input_t);
  TF_DeleteAbstractTensor(final_result);
  TFE_DeleteTensorHandle(final);
  TF_DeleteTensor(f_t);
  TF_DeleteAbstractFunction(func);
  TF_DeleteExecutionContext(graph_ctx);
  TF_DeleteExecutionContext(eager_execution_ctx);
  TF_DeleteExecutionContextOptions(eager_ctx_options);
  TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TF_AbstractFunction* func = TF_ExecutionContextToFunction(
      ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
  ASSERT_EQ(nullptr, func);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

  TF_DeleteExecutionContext(ctx);
  TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(graph_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(graph_options);
}
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
} // namespace


@@ -1,143 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
int64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
uint64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
int32 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
float value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
double value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
Eigen::half value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
tstring value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Scalar(complex128 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
bool value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_INT64, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_UINT64, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_INT32, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_FLOAT, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_HALF, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_STRING, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_BOOL, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorHandleInterface>
ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) {
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
return std::make_unique<TensorHandleInterface>(
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
/*op_device=*/nullptr, ctx_));
}
std::unique_ptr<AbstractOperationInterface>
ContextInterface::CreateOperation() {
return std::make_unique<tensorflow::OperationInterface>(ctx_);
}
void ContextInterface::ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) {
ctx_->ListDevices(devices);
}
} // namespace tensorflow


@@ -15,13 +15,15 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
@@ -32,123 +34,47 @@
// TensorHandles & Operations.
class AbstractContextInterface {
public:
virtual ~AbstractContextInterface() {}
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Scalar creation functions
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
int64 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
uint64 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
int32 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
float value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
double value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
Eigen::half value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
tstring value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
complex128 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
bool value) = 0;
// Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
// Tensor creation functions
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) = 0;
virtual AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) = 0;
// Copy the handle to another device.
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name,
Status* status) = 0;
// Create an operation to perform op execution
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
virtual AbstractOperationInterface* CreateOperation() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
protected:
virtual ~AbstractContextInterface() {}
};
// TODO(gjn): Try to move these all to EagerContext and make it implement
// AbstractContextInterface. Currently, this is not so straightforward because
// of various BUILD file dependencies.
class ContextInterface : public AbstractContextInterface {
public:
explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
~ContextInterface() override {}
std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
int64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
uint64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
int32 value) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
float value) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
double value) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
Eigen::half value) override;
std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
tensorflow::tstring value) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
tensorflow::complex128 value) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
bool value) override;
std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) override;
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
void ListDevices(std::vector<DeviceAttributes>* devices) override;
// For runtime-specific APIs, provides the ability to get the underlying context.
EagerContext* Context() const { return ctx_; }
private:
EagerContext* ctx_;
};
inline EagerContext* ContextFromInterface(
const std::unique_ptr<AbstractContextInterface>& context) {
return down_cast<ContextInterface*>(context.get())->Context();
}
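// Illustrative sketch (assumption): querying devices through the abstract
// interface and then unwrapping the concrete EagerContext for
// runtime-specific calls.
inline void ExampleListDevices(
    const std::unique_ptr<AbstractContextInterface>& ctx) {
  std::vector<DeviceAttributes> devices;
  ctx->ListDevices(&devices);
  EagerContext* eager_ctx = ContextFromInterface(ctx);
  (void)eager_ctx;  // Use for runtime-specific APIs only.
}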
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_

View File

@ -16,134 +16,16 @@ limitations under the License.
// A simple logging device to test custom device registration.
#include <memory>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -156,7 +38,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
@ -245,7 +127,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@ -296,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -331,7 +211,7 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -346,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -366,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
ASSERT_FALSE(arrived);
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
arrived = false;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Base case: two CPU inputs executes fine.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in same custom device works.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in different custom devices fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -394,5 +352,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}
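// Illustrative sketch (assumption): AllocateLoggingDevice pairs with
// TFE_RegisterCustomDevice for callers that register the device themselves.
// Deleting `device` after registration assumes the struct is copied, since
// TFE_RegisterCustomDevice takes it by value.
static void ExampleManualRegistration(TFE_Context* context, const char* name,
                                      bool* arrived_flag, bool* executed_flag,
                                      TF_Status* status) {
  TFE_CustomDevice* device;
  void* device_info;
  AllocateLoggingDevice(name, arrived_flag, executed_flag, &device,
                        &device_info);
  TFE_RegisterCustomDevice(context, *device, name, device_info, status);
  delete device;
}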

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
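// Illustrative usage sketch (assumption): register the logging device, copy a
// handle onto it, then unpack the wrapped handle for inspection. `h` is any
// handle on another device; error handling is abbreviated.
static inline TFE_TensorHandle* ExampleCopyAndUnpack(TFE_Context* context,
                                                     const char* name,
                                                     TFE_TensorHandle* h,
                                                     TF_Status* status) {
  bool arrived = false;
  bool executed = false;
  RegisterLoggingDevice(context, name, &arrived, &executed, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* logged =
      TFE_TensorHandleCopyToDevice(h, context, name, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return UnpackTensorHandle(logged, status);
}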
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
@ -40,9 +41,8 @@ struct TfDlManagedTensorCtx {
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
@ -286,7 +286,8 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
TFE_Context* ctx) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
@ -319,16 +320,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
return nullptr;
}
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
TFE_DeleteContext(ctx);
return handle;
}

View File

@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
TF_Status* status,
TFE_Context* ctx);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
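// Illustrative sketch (assumption): with the added TFE_Context parameter, a
// DLPack round trip reuses the caller's context rather than creating a
// temporary one internally. Assumes TF_GetCode is visible via the includes.
static inline TFE_TensorHandle* ExampleDLPackRoundTrip(TFE_TensorHandle* h,
                                                       TFE_Context* ctx,
                                                       TF_Status* status) {
  void* dlm = TFE_HandleToDLPack(h, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TFE_HandleFromDLPack(dlm, status, ctx);
}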

View File

@ -1,309 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/operation_interface.h"
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
(operation_.Device() == kVariantDeviceNull)
? operation_.EagerContext().HostCPU()
: operation_.Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status OperationInterface::SetDeviceName(const char* name) {
return operation_.SetDeviceName(name);
}
Status OperationInterface::SetAttrString(const char* attr_name,
const char* data, size_t length) {
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
return Status::OK();
}
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
return Status::OK();
}
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrType(const char* attr_name,
TF_DataType value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
return Status::OK();
}
Status OperationInterface::SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
operation_.MutableAttrs()->Set(attr_name, proto);
return Status::OK();
}
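// Illustrative sketch (assumption, hypothetical helper): per the logic above,
// a negative num_dims marks the shape attribute as unknown-rank, so `dims`
// may be null in that case.
static Status ExampleSetUnknownRankShape(OperationInterface* op,
                                         const char* attr_name) {
  return op->SetAttrShape(attr_name, /*dims=*/nullptr, /*num_dims=*/-1);
}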
Status OperationInterface::SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
const char* data,
size_t length) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
Status OperationInterface::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
operation_.MutableAttrs()->Set(attr_name, v);
return Status::OK();
}
Status OperationInterface::SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status OperationInterface::SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
return Status::OK();
}
const OpDef* OperationInterface::GetOpDef(Status* status) {
const tensorflow::OpDef* op_def = operation_.OpDef();
if (op_def) return op_def;
*status = OpDefForOp(Name(), &op_def);
return op_def;
}
Status OperationInterface::InputLength(const char* input_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Input '", input_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Output '", output_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
}
Status OperationInterface::Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
TF_RETURN_IF_ERROR(
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
for (int i = 0; i < *num_retvals; ++i) {
retvals->at(i).reset(
new tensorflow::TensorHandleInterface(handle_retvals[i]));
}
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
}
} // namespace tensorflow

View File

@ -15,36 +15,62 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include <memory>
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
virtual void Clear() = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;
// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
virtual Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status SetAttrString(const char* attr_name, const char* data,
@ -52,15 +78,15 @@ class AbstractOperationInterface {
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual Status SetAttrFunction(const char* attr_name,
const AbstractOperationInterface* value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
virtual Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
@ -68,103 +94,25 @@ class AbstractOperationInterface {
int num_values) = 0;
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values, int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) = 0;
virtual Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value, int num_values) = 0;
virtual Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperationInterface*> values) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) {
return errors::Unimplemented("SetUseXla not implemented");
}
virtual Status SetUseXla(bool enable) = 0;
virtual Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return errors::Unimplemented("SetCancellationManager not implemented");
}
};
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(EagerContext* ctx);
~OperationInterface() override {}
void Clear() override { operation_.Clear(); }
Status Reset(const char* op, const char* raw_device_name) override {
return operation_.Reset(op, raw_device_name, false, nullptr);
}
const string& Name() const override { return operation_.Name(); }
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) override;
Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override {
return operation_.OpDef();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, TF_DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
int num_values) override;
Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
protected:
virtual ~AbstractOperationInterface() {}
};
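// Illustrative sketch (assumption, hypothetical helper): driving the
// Span-based API above. "Identity" and the null raw_device_name are example
// values; `output` must point to storage for one handle.
inline Status ExampleRunUnaryOp(AbstractOperationInterface* op,
                                AbstractTensorHandleInterface* input,
                                AbstractTensorHandleInterface** output) {
  Status s = op->Reset("Identity", /*raw_device_name=*/nullptr);
  if (!s.ok()) return s;
  s = op->AddInput(input);
  if (!s.ok()) return s;
  int num_retvals = 1;
  s = op->Execute(absl::Span<AbstractTensorHandleInterface*>(output, 1),
                  &num_retvals);
  op->Release();  // Must use Release() rather than delete; see class comment.
  return s;
}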
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,597 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <memory>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
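// Illustrative sketch (assumption): packing one handle per component device
// into a single parallel tensor via the special-cased TPUReplicatedInput
// operation described above. `attributes` may be null here because the
// packing path never constructs an underlying op.
inline absl::optional<std::vector<MaybeParallelTensorOwned>> ExamplePack(
    const ParallelDevice& device, TFE_Context* context,
    const std::vector<TFE_TensorHandle*>& per_device_handles,
    TF_Status* status) {
  std::vector<MaybeParallelTensorUnowned> inputs(per_device_handles.begin(),
                                                 per_device_handles.end());
  return device.Execute(context, std::move(inputs), "TPUReplicatedInput",
                        /*attributes=*/nullptr, /*expected_max_outputs=*/1,
                        status);
}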
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
std::vector<TensorHandlePtr> components;
components.reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
std::string message(absl::StrCat(
"Expected all inputs to TPUReplicatedInput to be non-parallel "
"TensorHandles. The input ",
i,
" was a parallel tensor (already "
"placed on the parallel device)."));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
components.emplace_back(TFE_TensorHandleCopySharingTensor(
absl::get<TFE_TensorHandle*>(inputs[i]), status));
}
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} else if (operation_name == std::string("TPUReplicatedOutput")) {
// Special-cased operation for un-packing one parallel tensor into
// per-device tensors.
OpPtr op(TFE_NewOp(context, operation_name, status));
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Expected the input to "
"TPUReplicatedOutput to be a parallel tensor (placed on the "
"parallel device).");
return result;
}
ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
std::vector<MaybeParallelTensorOwned> outputs;
outputs.reserve(t->num_tensors());
for (int i = 0; i < t->num_tensors(); ++i) {
TensorHandlePtr this_output(
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
outputs.emplace_back(std::move(this_output));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(outputs));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(parallel_results.size());
for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
result_content.push_back(
MaybeParallelTensorOwned(std::move(parallel_result)));
}
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;  // Set by the first device's op; later devices must match.
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  int num_dims = TFE_TensorHandleNumDims(components[0].get(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  std::vector<int64_t> shape(num_dims);
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
// a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
.release();
}
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a parallel device.");
return nullptr;
}
// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
}
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
return;
}
std::vector<MaybeParallelTensorOwned> typed_outputs(
std::move(maybe_typed_outputs.value()));
if (typed_outputs.size() > *num_outputs) {
TF_SetStatus(status, TF_INTERNAL,
"The allocated output buffer was too small.");
return;
}
for (int i = 0; i < typed_outputs.size(); ++i) {
MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
.release();
if (TF_GetCode(status) != TF_OK) return;
}
}
*num_outputs = typed_outputs.size();
}
// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
}
} // namespace
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
custom_device.delete_device = &DeleteParallelDevice;
custom_device.execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
ParallelDevice* d =
new ParallelDevice(device_name, underlying_devices_vector);
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
}
} // namespace eager
} // namespace tensorflow

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
namespace eager {
// Register a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
// {"/job:localhost/replica:0/task:0/device:GPU:0",
// "/job:localhost/replica:0/task:0/device:GPU:1"}
// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also have
// the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
// parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
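//
// A minimal usage sketch (status checking omitted; the device names below are
// only examples, matching the tests for this device):
//
//   const char* underlying_devices[] = {
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:CPU:1"};
//   RegisterParallelDevice(
//       context, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//       underlying_devices, /*num_underlying_devices=*/2, status);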
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status);
} // namespace eager
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

@ -0,0 +1,917 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
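// For comparison, the function-pointer form that the functor avoids would look
// like this (a sketch; `raw_handle` stands in for any TFE_TensorHandle*):
//   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
//       handle(raw_handle, TFE_DeleteTensorHandle);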
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
  // Destroys the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
  // A handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
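// How these helpers are used in the tests below (a sketch; status checking is
// omitted, and `device_name`, `initial_value`, and `status` are placeholders):
//
//   Variable* variable = Variable::Create(context, TF_FLOAT, /*dims=*/nullptr,
//                                         /*num_dims=*/0, device_name, status);
//   variable->Assign(context, initial_value, status);
//   TensorHandlePtr read = variable->Read(context, status);
//   variable->Destroy(context, status);
//   delete variable;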
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
tensorflow::eager::RegisterParallelDevice(
context, device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1");
}
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:0");
}
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Skip the test if no TPU is available.
std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool has_tpu = false;
for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
++device_index) {
std::string device_type =
TF_DeviceListType(devices.get(), device_index, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (device_type == "TPU") {
has_tpu = true;
break;
}
}
if (has_tpu) {
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:TPU:0",
"/job:localhost/replica:0/task:0/device:TPU:0");
}
}
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CPU:0";
underlying_devices.push_back(first_device_name);
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CPU:1";
underlying_devices.push_back(second_device_name);
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Copying on to a parallel device is OK.
TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
cpu_value.get(), context.get(), device_name, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* backing_device =
TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(std::string(device_name), backing_device);
// Un-pack the parallel tensor to verify that the copy was successful.
{
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context.get(), device_value.get(), &components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Copies off of parallel devices must be explicit.
TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
device_value.get(), context.get(), first_device_name, status.get()));
ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
}
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create two vectors with different lengths
std::vector<float> size_two_value{1., 2.};
std::vector<float> size_three_value{1., 2., 3.};
TensorHandlePtr size_two(
VectorFloatTensorHandle(size_two_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr size_three(
VectorFloatTensorHandle(size_three_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Try to combine these values into a single parallel tensor.
std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
<< TF_Message(status.get());
}
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
3),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a parallel device with two CPUs
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> first_underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
tensorflow::eager::RegisterParallelDevice(
context.get(), first_device_name, first_underlying_devices.data(),
first_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a second parallel device with the first parallel device and one
// additional CPU.
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
std::vector<const char*> second_underlying_devices{
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
"/job:localhost/replica:0/task:0/device:CPU:2"};
tensorflow::eager::RegisterParallelDevice(
context.get(), second_device_name, second_underlying_devices.data(),
second_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the first parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr first_combined_value = CreatePerDeviceValues(
context.get(), components, first_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Nest the first parallel tensor into a second
TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
components[0] = first_combined_value.get();
components[1] = value_three.get();
TensorHandlePtr second_combined_value = CreatePerDeviceValues(
context.get(), components, second_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           three.get(), status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Un-pack the parallel tensor to verify that the operation was
// successful. The resulting structure should be:
// second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
std::array<TensorHandlePtr, 2> second_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
second_components[0].get(), status.get());
ASSERT_EQ(second_underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
second_components[1].get(), status.get());
ASSERT_EQ(second_underlying_devices[1], second_device);
// Un-pack the first parallel device's tensor too
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
ASSERT_EQ(first_underlying_devices[0], first_device);
second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
status.get());
ASSERT_EQ(first_underlying_devices[1], second_device);
}
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
{
// Try to pack two TensorHandles onto a parallel device with a single
// component.
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to extract the wrong number of components from a parallel tensor
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> incorrect_components;
ExtractPerDeviceValues(context.get(), combined_value.get(),
&incorrect_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a ParallelTensor to TPUReplicatedInput
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
TensorHandlePtr recombined_value = CreatePerDeviceValues(
context.get(), incorrect_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a non-parallel tensor to TPUReplicatedOutput
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
TFE_DeleteOp);
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
TFE_OpAddInput(op.get(), value_one.get(), status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetDevice(op.get(), device_name, status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_TensorHandle* result_handle;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
}
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
int group_size, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
TFE_OpSetAttrInt(op.get(), "group_size", group_size);
TFE_OpSetAttrInt(op.get(), "group_key", 0);
TFE_OpSetAttrInt(op.get(), "instance_key", 0);
const std::string merge_op("Add");
TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
final_op.length());
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Run a collective sum, so each component should now be the same.
TensorHandlePtr reduced(
CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
TF_DeleteGraph);
TF_OperationDescription* placeholder_desc =
TF_NewOperation(body.get(), "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Output x{placeholder_op, 0};
TF_OperationDescription* reduce_desc =
TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
TF_SetAttrInt(reduce_desc, "group_size", group_size);
TF_SetAttrInt(reduce_desc, "group_key", 0);
TF_SetAttrInt(reduce_desc, "instance_key", 0);
const std::string merge_op("Mul");
TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
final_op.length());
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
TF_AddInput(reduce_desc, x);
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Operation* operations[]{placeholder_op, reduce_op};
TF_Output y{reduce_op, 0};
const char* output_name = "y";
std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
TF_GraphToFunction(
/* fn_body */ body.get(), /* fn_name */ function_name,
/* append_hash_to_fn_name */ 0, /* num_opers */ 2,
/* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
/* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
/* opts */ nullptr, /* description */ "", /* status */ status),
TF_DeleteFunction);
if (TF_GetCode(status) != TF_OK) return;
TFE_ContextAddFunction(context, function.get(), status);
}
TEST(PARALLEL_DEVICE, TestFunction) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* function_name = "test_reduce_mul";
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* raw_result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr reduced(raw_result_handle);
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
result_components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}

@ -856,7 +856,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
VLOG(1) << "Final gradients size: "
<< gradients.size() - used_gradient_ids.size();
for (auto grad_pair : gradients) {
for (const auto& grad_pair : gradients) {
if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
for (const auto& g : grad_pair.second) {
vspace.DeleteGradient(g);

@ -15,11 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@ -34,75 +32,37 @@ namespace tensorflow {
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
public:
virtual ~AbstractTensorHandleInterface() {}
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// Check if the handle is in a valid initialized state.
virtual bool IsValid(Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
virtual tensorflow::DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(Status* status) const = 0;
virtual Status NumDims(int* num_dims) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(Status* status) const = 0;
virtual Status NumElements(int64* num_elements) const = 0;
// Returns size of specified dimension
virtual int64_t Dim(int dim_index, Status* status) const = 0;
virtual Status Dim(int dim_index, int64* dim) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
// Maintain mirror tensors for any implicit copies to local devices. This
// setting is offered on a per tensor handle basis to avoid potential memory
// over utilization due to holding on to mirrors as well as the original
// tensor. Note this setting overrides the context mirroring policy whereby if
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
// this tensor.
virtual void EnableImplicitMirroring() = 0;
protected:
virtual ~AbstractTensorHandleInterface() {}
};
// TODO(gjn): Try to move these all to TensorHandle and make it implement
// AbstractTensorHandleInterface. Currently, this is not so straightforward
// because of various BUILD file dependencies.
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface() override;
bool IsValid(Status* status) const override;
TF_DataType DataType() const override;
int NumDims(Status* status) const override;
int64_t NumElements(Status* status) const override;
int64_t Dim(int dim_index, Status* status) const override;
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
void EnableImplicitMirroring() override;
// For runtime specific APIs, provide ability to get the underlying handle.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
inline TensorHandle* TensorHandleFromInterface(
const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
return down_cast<TensorHandleInterface*>(handle.get())->Handle();
}
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

@ -0,0 +1,24 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#include "tensorflow/core/framework/cancellation.h"
struct TFE_CancellationManager {
tensorflow::CancellationManager cancellation_manager;
};
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
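
A small usage sketch (illustration only; the function name is hypothetical). `TFE_CancellationManager` simply exposes the core `CancellationManager` by value:

```c++
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"

// Cancels all pending work registered with this manager.
void CancelPending(TFE_CancellationManager* c) {
  c->cancellation_manager.StartCancel();
  // From here on, c->cancellation_manager.IsCancelled() returns true.
}
```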

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/eager/context_interface.h"
// Wraps a pointer to a context implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying context object. Instead, call
// TFE_DeleteContext who calls Release() on the context pointer and deletes
// the TFE_Context structure.
struct TFE_Context {
tensorflow::AbstractContextInterface* context;
};
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
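
The deletion contract spelled out in the comment can be sketched as follows (the real `TFE_DeleteContext` lives elsewhere in the C API; this only restates its documented behavior):

```c++
#include "tensorflow/c/eager/tfe_context_internal.h"

void DeleteContextSketch(TFE_Context* ctx) {
  if (ctx == nullptr) return;
  ctx->context->Release();  // Drops what may be the last reference.
  delete ctx;               // Frees only the wrapper struct itself.
}
```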

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
#include <memory>
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
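
A sketch of the two ownership modes encoded by the struct above (illustration only; the demo function is hypothetical):

```c++
#include "tensorflow/c/eager/tfe_executor_internal.h"

void ExecutorOwnershipDemo(tensorflow::EagerExecutor* borrowed) {
  // Owning mode: allocates a fresh EagerExecutor and keeps it alive.
  TFE_Executor owning(/*async=*/true);
  owning.executor();  // Returns owned_executor.get().

  // Non-owning mode: wraps a caller-provided executor via a raw pointer,
  // so `borrowed` must outlive this wrapper.
  TFE_Executor non_owning(borrowed);
  non_owning.executor();  // Returns `borrowed`.
}
```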

View File

@ -0,0 +1,146 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
#include <functional>
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/types.h"
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
explicit TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
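
An illustrative sketch of how these templates forward to the underlying monitoring library (the metric name below is made up for the example):

```c++
#include "tensorflow/c/eager/tfe_monitoring_internal.h"

void CountReplayedOp(const char* op_type) {
  // TFE_MonitoringCounter1 passes (name, description, label) straight
  // through to tensorflow::monitoring::Counter<1>::New.
  static auto* counter = new TFE_MonitoringCounter1(
      "/tensorflow/example/ops_replayed",  // hypothetical metric name
      "Number of ops replayed, partitioned by op type.", "op_type");
  counter->counter->GetCell(op_type)->IncrementBy(1);
}
```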

View File

@ -0,0 +1,52 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/eager/operation_interface.h"
// Wraps a pointer to an operation implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying operation object. Instead, call
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
// the TFE_Op structure.
struct TFE_Op {
tensorflow::AbstractOperationInterface* operation;
};
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
#include <vector>
#include "tensorflow/core/platform/types.h"
struct TFE_TensorDebugInfo {
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
: dev_dims(dims) {}
// Fully-padded, minor-to-major.
std::vector<tensorflow::int64> dev_dims;
};
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Wraps a pointer to a tensor handle implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
// the TFE_TensorHandle structure.
struct TFE_TensorHandle {
tensorflow::AbstractTensorHandleInterface* handle;
};
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

View File

@ -24,8 +24,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
using tensorflow::ServerFactory;

View File

@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {

View File

@ -37,9 +37,9 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

View File

@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
namespace tensorflow {

View File

@ -0,0 +1,66 @@
# Tensorflow C SavedModel API
## Overview
These are the new experimental C SavedModel APIs for loading and running
SavedModels in a TF2-idiomatic fashion. See
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
context.
The directory structure is as follows:
```none
saved_model/
public/
internal/
core/
```
## saved_model/public
`saved_model/public` is intended to house *only the public headers* of the
SavedModel C API.
These headers:
1. declare opaque C types (like `TF_SavedModel`),
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
Once they leave experimental, these APIs should be considered stable for use
by external clients.
These headers are in a separate directory to make it obvious to clients which
headers they should depend on, and which headers are implementation details.
Separating these public headers by directory also allows future programmatic
checks to ensure that TF public headers only `#include` other public TF headers.
## saved_model/internal
`saved_model/internal` is the "glue" between the C API and the internal C++
implementation.
Its role is to:
1. implement the C API functions declared in `saved_model/public`
2. define the C API types declared in `saved_model/public`
The files fulfilling 1. are named `*.cc` (e.g. `concrete_function.cc`), while
the files fulfilling 2. are `*type.h` (e.g. `concrete_function_type.h`).
The headers exposing the internal implementation of the opaque C types are only
visible to other implementors of the C API. This is similar to how other
TF C API implementations use `tf_status_internal.h` (to extract the underlying
`tensorflow::Status`). All other targets in this directory are private.
## saved_model/core
`saved_model/core` contains pure C++ "Classes" underlying the C API types
in `saved_model/public/`. These are implementation
details subject to change, and have limited visibility to implementors only.
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.
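
Concretely, the layering can be pictured with a hypothetical miniature (none of
these names exist in the tree; they only illustrate the pattern):

```c++
// --- public/thing.h: opaque C type and C functions only ---
typedef struct TF_Thing TF_Thing;
extern "C" int TF_ThingSize(TF_Thing* t);

// --- core/thing.h: the pure C++ class underneath ---
namespace tensorflow {
class Thing {
 public:
  int size() const { return 42; }
};
}  // namespace tensorflow

// --- internal/thing.cc: glue that casts the opaque pointer and forwards ---
extern "C" int TF_ThingSize(TF_Thing* t) {
  return reinterpret_cast<tensorflow::Thing*>(t)->size();
}
```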

View File

@ -0,0 +1,46 @@
# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
package(
default_visibility = [
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:operation_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
"//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
ConcreteFunction::GetCaptures() const {
return captures_;
}
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
namespace tensorflow {
// Note that ConcreteFunctions' lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class ConcreteFunction {
public:
virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetCallOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
const;
const FunctionMetadata& GetFunctionMetadata() const;
private:
FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
FunctionDef* function_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_

View File

@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
class Fused2Ops<string kernel> : NativeCodeCall<
"fuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
class Fused3Ops<string kernel> : NativeCodeCall<
"fuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;
namespace tensorflow {
def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
(Fused2Ops<"generic.mul_add"> $mul, $add)>;
class FunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class SavedModelAPI {
public:
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
// in a TF2 SavedModel.
// Note: `function` is a double pointer, so that implementations are
// able to return a pointer to an internal member.
virtual Status GetFunction(const std::string& function_path,
ConcreteFunction** function) = 0;
// Retrieve a function from a SavedModel, using the key of the
// SignatureDef map:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) = 0;
virtual const std::vector<ConcreteFunction*>& ListFunctions() = 0;
virtual ~SavedModelAPI() = default;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_

View File

@ -0,0 +1,157 @@
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# External clients should not worry about this directory; all contents are implementation details.
# Code in this directory is intended to form the glue between the C API and the internal C++
# implementation by
# 1. mapping C API calls onto corresponding methods of C++ objects
# 2. mapping opaque C types onto C++ classes
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
# C API functions exposed in tf/c/experimental/saved_model/public/.
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
# the opaque C types. These headers should only be visible to internal tensorflow
# implementors.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
],
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function_type",
":function_metadata",
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "concrete_function_list",
srcs = [
"concrete_function_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list_type",
":concrete_function_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_list_type",
hdrs = [
"concrete_function_list_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_type",
hdrs = [
"concrete_function_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "function_metadata",
srcs = [
"function_metadata.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "function_metadata_type",
hdrs = [
"function_metadata_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "saved_model_api",
srcs = [
"saved_model_api.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list",
":concrete_function_list_type",
":concrete_function_type",
":saved_model_api_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:lib",
],
)
cc_library(
name = "saved_model_api_type",
hdrs = [
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
extern "C" {
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata());
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
// internal header, and implement this function.
return nullptr;
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
return new TFE_Op{tensorflow::unwrap(func)->GetCallOp()};
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
extern "C" {
size_t TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list) {
return tensorflow::unwrap(list)->size();
}
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(std::vector<tensorflow::ConcreteFunction*>,
TF_ConcreteFunctionList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
// struct, since the lifetime of the struct and the raw pointer it wraps would
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
}
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
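
A round-trip sketch (reader's illustration; `RoundTripsCleanly` is hypothetical): because `wrap`/`unwrap` are pure `reinterpret_cast`s, crossing the C boundary allocates and copies nothing.

```c++
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"

bool RoundTripsCleanly(tensorflow::ConcreteFunction* fn) {
  TF_ConcreteFunction* c_handle = tensorflow::wrap(fn);
  return tensorflow::unwrap(c_handle) == fn;  // Always true.
}
```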

View File

@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
// TODO(bmzhao): Add getter functions here as necessary.

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
const char* const* tags, int tags_len,
TF_Status* status) {
// TODO(bmzhao): Add a virtual "LoadSavedModel" method to
// AbstractContextInterface, and call it here.
return nullptr;
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return tensorflow::wrap(&model->saved_model->ListFunctions());
}
} // end extern "C"

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

View File

@ -0,0 +1,63 @@
# Experimental SavedModel C APIs for TensorFlow.
# See RFC https://github.com/tensorflow/community/pull/207
# All headers are on the public surface of Tensorflow's C API.
# Once moved out of experimental, these will be stable.
# The idea behind a separate public/ directory is to make apparent
# which headers are part of TF's public interface (and which headers
# are implementation details). This structure allows us to also perform future
# programmatic checks that all "public" headers only include other "public"
# headers.
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
# cc_public_library would allow us to separate the header dep graph from header+srcs dep graph.
exports_files(
[
"concrete_function.h",
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
# The purpose of this header is to provide insulation against
# future changes where we rename/move a public header, without
# forcing all clients to change their "#includes".
cc_library(
name = "c_saved_model_api",
hdrs = ["c_saved_model_api.h"],
deps = [
":concrete_function",
":concrete_function_list",
":function_metadata",
":saved_model_api",
],
)
alias(
name = "concrete_function",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
)
alias(
name = "concrete_function_list",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
)
alias(
name = "function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
)
alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)

View File

@ -1,4 +1,4 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,19 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
#include "tensorflow/lite/experimental/ruy/ruy/platform.h"
#if RUY_PLATFORM(NEON)
#include "tensorflow/lite/experimental/ruy/ruy/kernel_arm.h"
#elif RUY_PLATFORM(X86)
#include "tensorflow/lite/experimental/ruy/ruy/kernel_x86.h"
#else
#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h"
#endif
#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
// IWYU pragma: begin_exports
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
// IWYU pragma: end_exports
#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that corresponds to a Function loaded from a SavedModel.
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
// C++ Unified Eager/Graph API's AbstractFunction
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func);
// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
TF_ConcreteFunction* func);
// Deletes `func`.
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunction(TF_ConcreteFunction* func);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
    TF_ConcreteFunctionList* list, int i);
#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus
#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#include "tensorflow/c/c_api_macros.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type used to store any metadata associated with a function.
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
// TODO(bmzhao): Add getters for fields as we determine what metadata
// we want to expose.
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type representing a Tensorflow "SavedModel"
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
// to achieve ABI stability.
typedef struct TF_SavedModel TF_SavedModel;
// Load a SavedModel from `dirname`.
//
// Params:
// dirname - The filepath of the directory containing the SavedModel.
// ctx - A TFE_Context containing optional load/TF runtime options.
// `ctx` must outlive the returned TF_SavedModel pointer.
// tags - Pointer to char* array of SavedModel tags. Optional if the SavedModel
// contains a single Metagraph, as for those exported from
// `tf.saved_model.save`.
// tags_len - number of elements in the `tags` array.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a newly created
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
TFE_Context* ctx,
const char* const* tags,
int tags_len,
TF_Status* status);
// Deletes a TF_SavedModel, and frees any resources owned by it.
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// model - The TF2 SavedModel to load a function from.
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// TODO(bmzhao): Add a detailed example of this with a
// python tf.module before moving this out of experimental.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. The lifetime of this instance is
// "conceptually" bound to `model`. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
TF_SavedModel* model, char* function_path, TF_Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// model - The SavedModel to load a function from.
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
TF_SavedModel* model);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
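
An end-to-end client sketch using only the functions declared above. Since `TF_LoadSavedModel` is still stubbed out in this commit, this shows the intended call sequence rather than code that runs yet; the model path and function path are made up:

```c++
void LoadAndInspect(TFE_Context* ctx, TF_Status* status) {
  const char* tags[] = {"serve"};
  TF_SavedModel* model =
      TF_LoadSavedModel("/tmp/my_model", ctx, tags, /*tags_len=*/1, status);
  if (TF_GetCode(status) != TF_OK) return;

  char function_path[] = "my_func";  // Hypothetical path to a tf.function.
  TF_ConcreteFunction* fn =
      TF_GetSavedModelConcreteFunction(model, function_path, status);
  if (TF_GetCode(status) == TF_OK) {
    TFE_Op* call_op = TF_ConcreteFunctionGetCallOp(fn);
    (void)call_op;  // Add inputs and execute with TFE_Execute as usual.
  }
  TF_DeleteSavedModel(model);  // Also invalidates `fn`.
}
```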

View File

@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#ifndef TENSORFLOW_C_TENSOR_INTERFACE_H_
#define TENSORFLOW_C_TENSOR_INTERFACE_H_
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a Tensor.
//
@ -28,10 +29,11 @@ limitations under the License.
// is needed a static_cast can be applied.
class AbstractTensorInterface {
public:
virtual ~AbstractTensorInterface() {}
// Release any underlying resources, including the interface object.
virtual void Release() = 0;
// Returns tensor dtype.
virtual TF_DataType Type() const = 0;
virtual DataType Type() const = 0;
// Returns number of dimensions.
virtual int NumDims() const = 0;
// Returns size of specified dimension
@ -47,37 +49,11 @@ class AbstractTensorInterface {
virtual bool IsAligned() const = 0;
// Returns whether there is sole ownership of this Tensor and thus it can be moved.
virtual bool CanMove() const = 0;
};
namespace tensorflow {
class TensorInterface : public AbstractTensorInterface {
public:
TensorInterface() {}
explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
~TensorInterface() override {}
TF_DataType Type() const override;
int NumDims() const override;
int64_t Dim(int dim_index) const override;
int64_t NumElements() const override;
size_t ByteSize() const override;
void* Data() const override;
bool IsAligned() const override;
bool CanMove() const override;
Status ToTensor(tensorflow::Tensor* dst) const;
Status BitcastFrom(const TensorInterface& from, TF_DataType type,
const int64_t* new_dims, int num_new_dims);
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
tensorflow::Tensor Tensor() { return tensor_; }
private:
tensorflow::Tensor tensor_;
protected:
virtual ~AbstractTensorInterface() {}
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_
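The protected destructor plus Release() pair means callers can never `delete` through the base pointer; deallocation stays with the module that created the object, which matters across shared-library boundaries. A minimal hypothetical subclass, implementing only the methods visible in this hunk, might look like this (not part of TensorFlow, shown only to illustrate the contract):

#include <cstddef>
#include <cstdint>
#include "tensorflow/c/tensor_interface.h"

// Hypothetical scalar-float tensor; heap-only because its destructor
// is private, so Release() is the sole deletion path.
class ScalarFloatTensor : public tensorflow::AbstractTensorInterface {
 public:
  explicit ScalarFloatTensor(float v) : value_(v) {}

  // Pairs with the protected base destructor above.
  void Release() override { delete this; }

  tensorflow::DataType Type() const override { return tensorflow::DT_FLOAT; }
  int NumDims() const override { return 0; }  // scalar
  int64_t Dim(int dim_index) const override { return 1; }
  int64_t NumElements() const override { return 1; }
  size_t ByteSize() const override { return sizeof(float); }
  void* Data() const override { return const_cast<float*>(&value_); }
  bool IsAligned() const override { return true; }
  bool CanMove() const override { return false; }

 private:
  ~ScalarFloatTensor() override = default;  // freed only via Release()
  float value_;
};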

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/error.h"
#include "tensorflow/core/platform/status.h"
using ::tensorflow::IOError;
using ::tensorflow::Status;

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_C_TF_STATUS_HELPER_H_
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
// Internal structures used by the status C API. These are likely to change
// and should not be depended on.

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/casts.h"
using tensorflow::Status;
using tensorflow::Tensor;
@ -82,7 +83,7 @@ TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
return new TF_Tensor{new tensorflow::TensorInterface(ret)};
}
} // namespace
@ -131,9 +132,21 @@ TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
return t->tensor->CanMove() ? t : nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
void TF_DeleteTensor(TF_Tensor* t) {
  if (t == nullptr) {
    return;
  }
  if (t->tensor) {
    t->tensor->Release();
  }
  delete t;
}
TF_DataType TF_TensorType(const TF_Tensor* t) {
  return static_cast<TF_DataType>(t->tensor->Type());
}
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
@ -159,15 +172,18 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
from->tensor.get()),
type, new_dims, num_new_dims));
tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
->BitcastFrom(
*tensorflow::down_cast<const tensorflow::TensorInterface*>(
from->tensor),
static_cast<tensorflow::DataType>(type), new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
void TensorInterface::Release() { delete this; }
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
@ -180,9 +196,7 @@ bool TensorInterface::CanMove() const {
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
DataType TensorInterface::Type() const { return tensor_.dtype(); }
int TensorInterface::NumDims() const { return tensor_.dims(); }
@ -202,15 +216,13 @@ void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
const int64_t* new_dims, int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
return tensor_.BitcastFrom(from.tensor_,
static_cast<tensorflow::DataType>(type), s);
return tensor_.BitcastFrom(from.tensor_, type, s);
}
} // namespace tensorflow
@ -327,7 +339,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
}
// DT_STRING tensors require a copy since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
@ -377,7 +389,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
return tensorflow::down_cast<const tensorflow::TensorInterface*>(src->tensor)
->ToTensor(dst);
}
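The public signatures are unchanged; only the internal ownership moved from a unique_ptr to Release()-based lifetime. A small sketch of the C-side contract, using only stable functions from tf_tensor.h:

#include <cstdint>
#include "tensorflow/c/tf_tensor.h"

// Sketch of the lifetime contract after this change. TF_DeleteTensor
// releases the interface object explicitly; the nullptr check guards
// the t->tensor dereference (plain `delete` was already null-safe).
void TensorLifetimeSketch() {
  int64_t dims[2] = {2, 3};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 6 * sizeof(float));
  // TF_TensorType casts the internal tensorflow::DataType back to the
  // C enum, so this is still TF_FLOAT.
  TF_DataType dtype = TF_TensorType(t);
  (void)dtype;
  TF_DeleteTensor(t);
  TF_DeleteTensor(nullptr);  // explicitly a no-op
}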

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
@ -31,7 +32,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
std::unique_ptr<AbstractTensorInterface> tensor;
tensorflow::AbstractTensorInterface* tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -86,6 +87,41 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// Defaults to deallocating using CPU allocator. You can pass pointer to
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
class TensorInterface : public AbstractTensorInterface {
public:
TensorInterface() {}
explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
~TensorInterface() override {}
void Release() override;
DataType Type() const override;
int NumDims() const override;
int64_t Dim(int dim_index) const override;
int64_t NumElements() const override;
size_t ByteSize() const override;
void* Data() const override;
bool IsAligned() const override;
bool CanMove() const override;
Status ToTensor(tensorflow::Tensor* dst) const;
Status BitcastFrom(const TensorInterface& from, DataType type,
const int64_t* new_dims, int num_new_dims);
tensorflow::Tensor& Tensor() { return tensor_; }
private:
tensorflow::Tensor tensor_;
};
inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
return down_cast<TensorInterface*>(tensor)->Tensor();
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
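TensorFromInterface gives internal callers a zero-copy unwrap of the C handle. A short hypothetical use, assuming this internal header and a TF_Tensor* backed by a TensorInterface:

// Hypothetical internal helper (not part of this change): borrow the
// underlying tensorflow::Tensor without copying; valid only while the
// TF_Tensor is alive. down_cast checks the dynamic type in debug builds.
int64_t NumElementsOf(TF_Tensor* t) {
  tensorflow::Tensor& underlying = tensorflow::TensorFromInterface(t->tensor);
  return underlying.NumElements();
}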

View File

@ -14,10 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::GraphDef;

View File

@ -156,6 +156,7 @@ cc_library(
":array_grad",
":data_flow_grad",
":image_grad",
":manip_grad",
":math_grad",
":nn_grad",
],
@ -494,6 +495,32 @@ tf_cc_test(
],
)
cc_library(
name = "manip_grad",
srcs = ["gradients/manip_grad.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":gradients",
],
alwayslink = 1,
)
tf_cc_test(
name = "gradients_manip_grad_test",
srcs = ["gradients/manip_grad_test.cc"],
deps = [
":array_ops",
":cc_ops",
":gradient_checker",
":manip_grad",
":testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
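The new manip_grad target supplies C++ gradients for the manip ops. A hedged sketch of what gradients/manip_grad.cc plausibly registers, following the grad_op_registry pattern of the sibling *_grad files (the gradient of Roll is a Roll by the negated shift; the exact contents of the new file are not shown in this diff):

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Rolling the incoming gradient back by -shift along the same axis
// undoes the forward Roll; shift and axis themselves get no gradient.
Status RollGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shift = op.input(1);
  auto axis = op.input(2);
  auto grad = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
  grad_outputs->push_back(grad);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Roll", RollGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow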
# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend only on these ops.
tf_gen_op_wrappers_cc(
name = "math_ops",

Some files were not shown because too many files have changed in this diff.