Merge branch 'master' into tf32_test_fix

This commit is contained in:
Reed 2020-08-25 10:30:06 -07:00 committed by GitHub
commit 355d6b55b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1881 changed files with 64299 additions and 23095 deletions

View File

@ -461,12 +461,12 @@ build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 10.1.
# Map default to CUDA 11 for PY35 and greater.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
@ -583,9 +583,9 @@ build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
@ -595,8 +595,7 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc

View File

@ -22,6 +22,7 @@
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
@ -33,6 +34,9 @@
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
## Known Caveats
@ -77,6 +81,12 @@
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state after restart.
* Added tf.data service support for sharing dataset graphs via shared
filesystem instead of over RPC. This reduces load on the dispatcher,
improving performance of distributing datasets. For this to work, the
dispatcher's `work_dir` must be accessible from workers. If the worker
fails to read from the `work_dir`, it falls back to using RPC for dataset
graph transfer.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
@ -84,11 +94,14 @@
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overriden.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Given the same seed, the stateless functions
produce the same results independent of how many times the function is
called, and independent of global seed settings.
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
@ -100,8 +113,13 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
@ -145,6 +163,14 @@
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
@ -241,6 +267,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Remove environmental variable `TF_USE_CUDNN`.
* Others
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
@ -1572,6 +1599,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric,
color palette of the frame. This has been fixed now
* image.resize now considers proper pixel centers and has new kernels
(incl. anti-aliasing).
* Added an isotonic regression solver (tf.nn.isotonic_regression).
* Performance
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector

View File

@ -16,5 +16,5 @@
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
python "%configure_dir%\configure.py" %* || ( exit /b )
echo Configuration finished

View File

@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"c_api_macros.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
@ -57,10 +58,11 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
cc_library(
name = "pywrap_required_hdrs",
srcs = [
textual_hdrs = [
"c_api_internal.h",
"c_api_macros.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
@ -79,6 +81,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"c_api_macros.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
@ -217,6 +220,7 @@ cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",
@ -310,6 +314,7 @@ cc_library(
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
@ -336,6 +341,7 @@ tf_cuda_library(
],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",

View File

@ -30,4 +30,17 @@ limitations under the License.
#endif // _WIN32
#endif // SWIG
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
// the datatype for boolean tensors.
#ifndef TF_Bool
#define TF_Bool unsigned char
#endif // TF_Bool
// Macro used to calculate struct size for maintaining ABI stability across
// different struct implementations.
#ifndef TF_OFFSET_OF_END
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
#endif // TF_OFFSET_OF_END
#endif // TENSORFLOW_C_C_API_MACROS_H_

View File

@ -240,6 +240,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",
@ -249,6 +250,76 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "mnist_gradients_testutil",
srcs = [
"mnist_gradients_testutil.cc",
],
hdrs = [
"mnist_gradients_testutil.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "mnist_gradients_test",
size = "small",
srcs = [
"mnist_gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
"notap", # TODO(b/166150182): Enable
"no_oss", # TODO(b/166150182): Enable
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@ -508,6 +579,27 @@ tf_cuda_cc_test(
],
)
tf_cuda_library(
name = "c_api_remote_test_util",
testonly = 1,
srcs = ["c_api_remote_test_util.cc"],
hdrs = ["c_api_remote_test_util.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_remote_test",
size = "small",
@ -524,6 +616,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_remote_test_util",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
@ -540,6 +633,25 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_remote_function_test",
size = "small",
srcs = [
"c_api_remote_function_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
deps = [
":c_api_remote_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",

View File

@ -518,7 +518,8 @@ void TestDistributedFunctionCancellation(bool inject_error) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunctionWithGraphError();
const string function_def = inject_error ? VariableAddFunctionWithGraphError()
: VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
} // namespace

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -116,242 +117,24 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
} // namespace

View File

@ -0,0 +1,222 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using ::tensorflow::string;
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
if (remote_func_outputs) {
const string backing_device =
TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(backing_device, task2_name);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

View File

@ -102,6 +102,32 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];

View File

@ -40,6 +40,14 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims);
// Get a Matrix TensorHandle with given float values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims);
// Get a Matrix TensorHandle with given int values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);

View File

@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
// TODO(b/145674566): We use Graph::NewName to get a unique name here but
// this may not be consistent with python's naming policy.
mutex_lock l(g_->mu);
op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
g_->graph.NewName(op_name).c_str()));
return Status::OK();
}
const string& Name() const override { return op_type_; }

View File

@ -30,6 +30,9 @@ using tensorflow::string;
namespace tensorflow {
namespace {
// The tests are parameterized on:
// - a string representing the tracing implementation: "mlir" or "graphdef".
// - a boolean that when true enables TFRT as the execution engine.
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
@ -554,7 +557,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
@ -576,7 +579,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
@ -983,6 +986,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
// The above tests are run for a combination of:
// - graphdef and MLIR tracing engine
// - Using TFRT as an execution runtime (true == enable TFRT)
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(t)).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
} // namespace
class IncomingGradientsImpl : public IncomingGradients {
public:
explicit IncomingGradientsImpl(
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
DefaultGradientFunction* default_gradients)
: grad_inputs_(grad_inputs),
ctx_(ctx),
default_gradients_(default_gradients) {}
AbstractTensorHandle* operator[](int i) const override {
return default_gradients_->get(ctx_, grad_inputs_, i);
}
size_t size() const override { return grad_inputs_.size(); }
private:
absl::Span<AbstractTensorHandle* const> grad_inputs_;
Context* ctx_;
DefaultGradientFunction* default_gradients_;
};
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
: outputs_(op.outputs) {
for (auto output : outputs_) {
output->Ref();
}
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
if (grad_inputs[i]) {
return grad_inputs[i];
}
if (cached_default_grads_[i]) {
return cached_default_grads_[i].get();
}
AbstractTensorHandle* result = nullptr;
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
if (!s.ok()) {
if (result) {
result->Unref();
}
VLOG(1) << "Failed to create ZerosLike for index " << i;
return nullptr;
}
cached_default_grads_[i].reset(result);
return result;
}
PassThroughDefaultGradients::PassThroughDefaultGradients(
const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
return grad_inputs[i];
}
Status GradientRegistry::Register(
const string& op_name, BackwardFunctionFactory backward_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
registry_.insert({op_name, backward_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
std::unique_ptr<BackwardFunction>* backward_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
backward_function->reset(iter->second(op));
return Status::OK();
}
@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
IncomingGradientsImpl incoming_gradients(
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
return backward_function->GetGradientFunction()->Compute(
&ctx, incoming_gradients, result);
}
// Looks up the ID of a Gradient.
@ -373,15 +424,15 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
[registry, forward_op_]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry.Lookup(*forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
return backward_fn.release();
},
[](GradientFunction* ptr) {
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}

View File

@ -55,18 +55,25 @@ struct Context {
public:
AbstractContext* ctx;
};
class IncomingGradients {
public:
virtual AbstractTensorHandle* operator[](int i) const = 0;
virtual size_t size() const = 0;
virtual ~IncomingGradients() {}
};
class GradientFunction {
public:
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
// `grad_inputs`.
virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
public:
string op_name;
@ -76,18 +83,86 @@ struct ForwardOperation {
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
virtual AbstractTensorHandle* get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) = 0;
virtual ~DefaultGradientFunction() {}
};
// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
public:
explicit AllZerosDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
// TODO(srbs): We do not always need to keep the tensors around. In immediate
// execution mode we just need to store the shape and dtype. During tracing
// we may need to keep the tensor around if the shape is not full defined.
std::vector<AbstractTensorHandle*> outputs_;
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};
// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
public:
explicit PassThroughDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
};
// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
public:
BackwardFunction(GradientFunction* gradient_function,
DefaultGradientFunction* default_gradients)
: gradient_function_(gradient_function),
default_gradients_(default_gradients) {}
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
DefaultGradientFunction* GetDefaultGradientFunction() {
return default_gradients_.get();
}
private:
std::unique_ptr<GradientFunction> gradient_function_;
std::unique_ptr<DefaultGradientFunction> default_gradients_;
};
using BackwardFunctionFactory =
std::function<BackwardFunction*(const ForwardOperation& op)>;
// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
BackwardFunctionFactory backward_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const;
private:
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@ -123,7 +205,7 @@ class TapeTensor {
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
// The context where OnesLike ops are to be created.
AbstractContext* ctx_;
};
@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
@ -147,7 +229,7 @@ class TapeVSpace
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
@ -168,8 +250,14 @@ class TapeVSpace
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
BackwardFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -23,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@ -35,6 +37,8 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
@ -47,6 +51,7 @@ class CppGradients
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
return Status::OK();
}
@ -60,9 +65,9 @@ Status Add(AbstractContext* ctx, Tape* tape,
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
@ -81,9 +86,9 @@ Status Exp(AbstractContext* ctx, Tape* tape,
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(exp_op.get())) {
if (isa<TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(exp_op.get())->SetOpName("my_exp"));
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
@ -91,6 +96,26 @@ Status Exp(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(identity_n_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
->SetOpName("my_identity_n"));
}
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
int num_retvals = outputs.size();
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
tape, registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -113,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx,
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
@ -143,7 +169,8 @@ Status ExpGradModel(AbstractContext* ctx,
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
@ -152,6 +179,41 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
absl::MakeSpan(identity_n_outputs), registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -183,21 +245,36 @@ Status RunModel(Model model, AbstractContext* ctx,
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
@ -212,8 +289,19 @@ Status RunModel(Model model, AbstractContext* ctx,
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
@ -360,18 +448,77 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(true, false),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif

View File

@ -57,15 +57,10 @@ class ImmediateExecutionContext : public AbstractContext {
// Create a tensor instance from the given data buffer and description.
// `memory_releaser` will be called on destruction, and it's responsible for
// cleaning up the underlying buffer. `convert_string` indicates whether it
// has to handle tstring conversion. Expected to be removed once tstring
// migration is done.
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
const int64_t* dims,
int num_dims, void* data,
size_t len, bool convert_string,
MemoryReleaser memory_releaser,
void* memory_releaser_arg) = 0;
// cleaning up the underlying buffer.
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(

View File

@ -0,0 +1,781 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
return Status::OK();
}
// ========================= Test Util Functions ==============================
// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
// =========================== Start Tests ================================
TEST_P(CppGradients, TestMatMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(A)
* tape.watch(B)
* Y = AB
* outputs = tape.gradient(Y, [A, B])
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
}
TF_Tensor* dB_tensor;
s = GetValue(outputs[1], &dB_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dB_tensor),
TF_TensorByteSize(dB_tensor));
float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(dA_tensor);
TF_DeleteTensor(dB_tensor);
}
TEST_P(CppGradients, TestMNISTForward) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t dims_y[] = {2};
num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[2] = {9.6f, 27.2f};
for (int j = 0; j < 2; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMNISTForward2) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {3, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[3] = {9.6f, 27.2f, 44.8f};
for (int j = 0; j < 3; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMatMulTranspose) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {2, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
GradientRegistry registry;
// Run the MatMul Op
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
}
TEST_P(CppGradients, TestReluGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(X)
* Y = Relu(X)
* outputs = tape.gradient(Y, [X])
*/
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[9] = {0};
memcpy(&result_data[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
float tolerance = 1e-3;
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
}
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestSoftmaxLossGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(X)
* tape.watch(labels)
* loss = SoftmaxLoss(X, labels)
* outputs = tape.gradient(loss, [X, labels])
*
*
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[9] = {0};
memcpy(&result_data[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f,
0.6652f, 0.8437f, -0.8858f, 0.0420f};
float tolerance = 1e-3;
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
}
// Only Unref() first output as 2nd is nullptr grad for labels
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestMNISTGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
*
* tape.watch(W1)
* tape.watch(W2)
* mm = X*W1
* hidden = Relu(mm)
* scores = W2*hidden
* loss = SoftmaxLoss(scores, y)
* outputs = tape.gradient(loss, [A, B])
*
*/
std::vector<AbstractTensorHandle*> outputs(3);
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float tolerance = 1e-3;
TF_Tensor* dW1_tensor;
s = GetValue(outputs[0], &dW1_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dW1_tensor),
TF_TensorByteSize(dW1_tensor));
float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
; // dLoss
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
}
TF_Tensor* dW2_tensor;
s = GetValue(outputs[1], &dW2_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dW2_tensor),
TF_TensorByteSize(dW2_tensor));
float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
outputs[2]->Unref();
TF_DeleteTensor(dW1_tensor);
TF_DeleteTensor(dW2_tensor);
}
TEST_P(CppGradients, TestScalarMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr eta;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
eta.reset(x_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
GradientRegistry registry;
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float tolerance = 1e-3;
float eta_val = 1.5f;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
}
outputs[0]->Unref();
TF_DeleteTensor(dA_tensor);
}
TEST_P(CppGradients, TestMNIST_Training) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// TODO(amturati): use random initializer for weights instead of
// constant values.
// W1 = first weights
float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Prepare for training
std::vector<AbstractTensorHandle*> weights;
weights.push_back(W1.get());
weights.push_back(W2.get());
// Set learning rate to be 1e-1
AbstractTensorHandle* learning_rate = nullptr;
s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Train
int num_iters = 10;
std::vector<AbstractTensorHandle*> mnist_outputs(3);
std::vector<AbstractTensorHandle*> grads(2);
for (int i = 0; i < num_iters; i++) {
// Run Forward Pass
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), weights[0], weights[1], y.get()},
absl::MakeSpan(mnist_outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Fill grads
grads[0] = mnist_outputs[0];
grads[1] = mnist_outputs[1];
// Gradient Update
s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
grads[0]->Unref(); // release W1_grad
grads[1]->Unref(); // release W2_grad
mnist_outputs[2]->Unref(); // release loss
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,594 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using std::vector;
using tracing::TracingOperation;
// ========================== Tape Ops ==============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
int num_retvals = 1;
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractTensorHandle* scores = inputs[0];
AbstractTensorHandle* labels = inputs[1];
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
int num_retvals = 2; // returns loss values and backprop
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
//===================== Test Models to run =========================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto mm_output : mm_outputs) {
mm_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
// Model to run 2-layer net
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
/**
* We will trace a 2-layer fully connected network for an MNIST model:
*
* def mnist_forward(X, W1, W2, y_labels):
* mm_out_1 = tf.matmul(X,W1)
* hidden_layer = tf.nn.relu(mm_out_1)
* scores = tf.matmul(hidden_layer,W2)
* softmax =
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) return
* scores, softmax
*
* Use this convention for inputs:
*
* inputs = [X, W1, W2, y_labels]
*
*/
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
absl::MakeSpan(temp_outputs), "relu",
registry)); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
outputs[0] = scores;
outputs[1] = loss_vals;
delete tape;
return Status::OK();
}
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(X));
tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/true,
/*transpose_b=*/false, registry)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1);
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
"relu0", registry)); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto relu_output : relu_outputs) {
relu_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/true);
tape->Watch(ToId(X)); // Watch X.
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W1.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0", registry));
AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0];
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(
tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
/*source_tensor_ids=*/{ToId(W1), ToId(W2)},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
// Only release 2nd temp output as first holds loss values.
temp_outputs[1]->Unref();
outputs[0] = out_grads[0]; // dW1
outputs[1] = out_grads[1]; // dW2
outputs[2] = loss;
delete tape;
return Status::OK();
}
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* eta = inputs[0];
AbstractTensorHandle* A = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
"scalarMul0", registry)); // Compute eta*A
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
// ============================= End Models ================================
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate) {
/* Update weights one by one using gradient update rule:
*
* w -= lr*grad[w]
*
* NOTE: assuming learning rate is positive
*/
Status s;
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs),
update_str.c_str());
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs),
update_str.c_str());
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}

View File

@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
// ========================== Tape Ops ==============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// ====================== End Tape Ops ============================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes 2-layer Neural Network with Softmax Loss.
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes MatMul with first matrix tranposed.
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify ReluGrad functionality
Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify SoftmaxGrad functionality
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify scalar-tensor multiplication Op
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Updates the weights for a neural network given incoming grads and learning
// rate
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
AbstractContext* BuildFunction(const char* fn_name);
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

View File

@ -76,10 +76,26 @@ cc_library(
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],
)
tf_cc_test(
name = "parallel_device_lib_test",
srcs = ["parallel_device_lib_test.cc"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "parallel_device_testlib",
testonly = 1,

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@ -118,6 +119,9 @@ class DeviceThread {
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
// TF_Status is an incomplete type and so can't be stack allocated. To avoid
// unnecessary allocations each Execute call, we keep one heap-allocated
// version for the thread.
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
@ -188,6 +192,9 @@ std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
// Reset the member `status_` so future op executions (after recovery from
// the bad `status`) start with an OK status.
TF_SetStatus(status_.get(), TF_OK, "");
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
@ -255,18 +262,27 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
if (values.size() != num_underlying_devices()) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"Number of values did not match number of underlying devices.");
return nullptr;
}
for (int device_index = 0; device_index < num_underlying_devices();
++device_index) {
int32_t* device_id = new int32_t;
*device_id = device_index;
int32_t* device_value = new int32_t;
*device_value = values[device_index];
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int32_t*>(data);
@ -295,6 +311,16 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
std::vector<int32_t> ids;
ids.reserve(num_underlying_devices());
for (int i = 0; i < num_underlying_devices(); ++i) {
ids.push_back(i);
}
return Vector(context, status, ids);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
const std::vector<ParallelTensor*>& inputs,
@ -319,21 +345,36 @@ ParallelDevice::Execute(TFE_Context* context,
std::move(device_inputs), attributes,
expected_max_outputs);
}
StatusPtr first_bad_status(nullptr);
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
// We will run every Join even if there are bad statuses in case the user
// wants to recover and continue running ops on the parallel device (which
// would otherwise deadlock).
if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
TF_Message(status));
}
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
if (first_bad_status == nullptr &&
per_device_output_tensors.rbegin()->size() != first_op_output_count) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
if (first_bad_status != nullptr) {
TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
TF_Message(first_bad_status.get()));
return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
@ -61,6 +62,11 @@ class ParallelDevice {
TFE_TensorHandle* tensor,
TF_Status* status) const;
// Construct a parallel tensor consisting of the scalar values from `values`.
std::unique_ptr<ParallelTensor> Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
auto outputs =
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
std::vector<ParallelTensor*> handle_inputs;
handle_inputs.reserve(handles.size());
for (auto& handle : handles) {
handle_inputs.push_back(handle.get());
}
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
TFE_OpGetAttrs(read_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
TF_SetStatus(status.get(), TF_OK, "");
// Check that ops still run successfully on the device.
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -146,13 +146,16 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
// When running backward functions, builds zeros-like tensors for
// incoming grads which are nullptrs, unless `build_default_zeros_grads`
// is set to false.
Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result);
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
bool IsPersistent() const { return persistent_; }
@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
bool build_default_zeros_grads) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(nullptr);
zero_indices.push_back(i);
out_gradients.push_back(nullptr);
if (build_default_zeros_grads) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
func_name_it->second.find(i) == func_name_it->second.end()) {
zero_indices.push_back(i);
}
}
} else {
any_gradient_nonzero = true;
@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
}
std::vector<Gradient*> in_gradients;
DCHECK(build_default_zeros_grads || zero_indices.empty());
if (any_gradient_nonzero) {
for (const auto i : zero_indices) {
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

View File

@ -879,32 +879,6 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND`
// if the object does not exist. In that case, we will have to check if the
// `src` is a directory or not to set the correspondent `status` (i.e
// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if
// path `src` is a directory).
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src);
TF_SetStatusFromGCSStatus(gcs_status, status);
}
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status) {
std::string bucket_src, object_src;
@ -956,6 +930,100 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
return false;
}
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
RenameObject(filesystem, src, dst, status);
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string src_dir = src;
std::string dst_dir = dst;
MaybeAppendSlash(&src_dir);
MaybeAppendSlash(&dst_dir);
for (const std::string& children : childrens) {
RenameObject(filesystem, src_dir + children, dst_dir + children, status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files, uint64_t* undeleted_dirs,
TF_Status* status) {
if (!undeleted_files || !undeleted_dirs)
return TF_SetStatus(
status, TF_INTERNAL,
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
*undeleted_files = 0;
*undeleted_dirs = 0;
if (!IsDirectory(filesystem, path, status)) {
*undeleted_dirs = 1;
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string dir = path;
MaybeAppendSlash(&dir);
for (const std::string& children : childrens) {
const std::string& full_path = dir + children;
DeleteFile(filesystem, full_path.c_str(), status);
if (TF_GetCode(status) != TF_OK) {
if (IsDirectory(filesystem, full_path.c_str(), status))
// The object is a directory marker.
(*undeleted_dirs)++;
else
(*undeleted_files)++;
}
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = childrens.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(childrens[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
@ -993,6 +1061,17 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
}
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
static void FlushCaches(const TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->Flush();
gcs_file->stat_cache->Clear();
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -1009,6 +1088,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
@ -1018,6 +1104,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir;
ops->filesystem_ops->delete_recursively =
tf_gcs_filesystem::DeleteRecursively;
ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists;
ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory;
ops->filesystem_ops->stat = tf_gcs_filesystem::Stat;
ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName;
ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -26,6 +26,8 @@ cc_library(
}),
deps = [
":aws_crypto",
":aws_logging",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",
@ -45,6 +47,18 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "aws_logging",
srcs = ["aws_logging.cc"],
hdrs = ["aws_logging.h"],
deps = [
"//tensorflow/c:logging",
"@aws",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
)
tf_cc_test(
name = "s3_filesystem_test",
srcs = [

View File

@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
#include <aws/core/Aws.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <cstdarg>
#include <cstdio>
#include <sstream>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/logging.h"
static constexpr char kAWSLoggingTag[] = "AWSLogging";
static const std::map<const std::string, const Aws::Utils::Logging::LogLevel>
log_levels_string_to_aws = {
{"off", Aws::Utils::Logging::LogLevel::Off},
{"fatal", Aws::Utils::Logging::LogLevel::Fatal},
{"error", Aws::Utils::Logging::LogLevel::Error},
{"warn", Aws::Utils::Logging::LogLevel::Warn},
{"info", Aws::Utils::Logging::LogLevel::Info},
{"debug", Aws::Utils::Logging::LogLevel::Debug},
{"trace", Aws::Utils::Logging::LogLevel::Trace}};
static const std::map<const int, const Aws::Utils::Logging::LogLevel>
log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info},
{1, Aws::Utils::Logging::LogLevel::Warn},
{2, Aws::Utils::Logging::LogLevel::Error},
{3, Aws::Utils::Logging::LogLevel::Fatal}};
namespace tf_s3_filesystem {
AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)
: log_level_(log_level) {}
void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level,
const std::string& message) {
if (message == "Initializing Curl library") return;
switch (log_level) {
case Aws::Utils::Logging::LogLevel::Info:
TF_Log(TF_INFO, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Warn:
TF_Log(TF_WARNING, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Error:
TF_Log(TF_ERROR, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Fatal:
TF_Log(TF_FATAL, message.c_str());
break;
default:
// this will match for DEBUG, TRACE
TF_Log(TF_INFO, message.c_str());
break;
}
}
void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const char* format, ...) {
char buffer[256];
va_list args;
va_start(args, format);
vsnprintf(buffer, 256, format, args);
va_end(args);
LogMessage(log_level, buffer);
}
void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level,
const char* tag,
const Aws::OStringStream& message_stream) {
LogMessage(log_level, message_stream.rdbuf()->str().c_str());
}
void AWSLogSystem::Flush() { return; }
static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) {
// Converts TF Log Levels INFO, WARNING, ERROR and FATAL to the AWS enum
// values for the levels
if (log_levels_tf_to_aws.find(level) != log_levels_tf_to_aws.end()) {
return log_levels_tf_to_aws.at(level);
} else {
// default to fatal
return Aws::Utils::Logging::LogLevel::Fatal;
}
}
static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() {
// defaults to FATAL log level for the AWS SDK
// this is because many normal tensorflow operations are logged as errors in
// the AWS SDK such as checking if a file exists can log an error in AWS SDK
// if the file does not actually exist. Another such case is when reading a
// file till the end, TensorFlow expects to see an InvalidRange exception at
// the end, but this would be an error in the AWS SDK. This confuses users,
// hence the default setting.
Aws::Utils::Logging::LogLevel log_level =
Aws::Utils::Logging::LogLevel::Fatal;
const char* aws_env_var_val = getenv("AWS_LOG_LEVEL");
if (aws_env_var_val != nullptr) {
std::string maybe_integer_str(aws_env_var_val, strlen(aws_env_var_val));
std::istringstream ss(maybe_integer_str);
int level;
ss >> level;
if (ss.fail()) {
// wasn't a number
// expecting a string
std::string level_str = maybe_integer_str;
if (log_levels_string_to_aws.find(level_str) !=
log_levels_string_to_aws.end()) {
log_level = log_levels_string_to_aws.at(level_str);
}
} else {
// backwards compatibility
// valid number, but this number follows the standard TensorFlow log
// levels need to convert this to AWS SDK logging level number
log_level = TfLogLevelToAwsLogLevel(level);
}
}
return log_level;
}
static bool initialized = false;
ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit);
void AWSLogSystem::InitializeAWSLogging() {
absl::MutexLock l(&s3_logging_mutex);
if (!initialized) {
Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared<AWSLogSystem>(
kAWSLoggingTag, ParseAwsLogLevelFromEnv()));
initialized = true;
return;
}
}
void AWSLogSystem::ShutdownAWSLogging() {
absl::MutexLock l(&s3_logging_mutex);
if (initialized) {
Aws::Utils::Logging::ShutdownAWSLogging();
initialized = false;
return;
}
}
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
#include <aws/core/utils/logging/LogLevel.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <atomic>
#include <string>
namespace tf_s3_filesystem {
class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface {
public:
static void InitializeAWSLogging();
static void ShutdownAWSLogging();
explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level);
virtual ~AWSLogSystem() = default;
// Gets the currently configured log level.
Aws::Utils::Logging::LogLevel GetLogLevel(void) const override {
return log_level_;
}
// Set a new log level. This has the immediate effect of changing the log.
void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) {
log_level_.store(log_level);
}
// Does a printf style output to ProcessFormattedStatement. Don't use this,
// it's unsafe. See LogStream.
void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const char* format, ...) override;
// Writes the stream to ProcessFormattedStatement.
void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const Aws::OStringStream& messageStream) override;
// Flushes the buffered messages if the logger supports buffering
void Flush() override;
private:
void LogMessage(Aws::Utils::Logging::LogLevel log_level,
const std::string& message);
std::atomic<Aws::Utils::Logging::LogLevel> log_level_;
};
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_

View File

@ -38,6 +38,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for S3 environments.
@ -186,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
absl::MutexLock l(&s3_file->initialization_lock);
if (s3_file->s3_client.get() == nullptr) {
tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging();
Aws::SDKOptions options;
options.cryptoOptions.sha256Factory_create_fn = []() {
return Aws::MakeShared<tf_s3_filesystem::AWSSHA256Factory>(
@ -250,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) {
delete s3_client;
Aws::SDKOptions options;
Aws::ShutdownAPI(options);
tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
}
}
@ -281,6 +286,7 @@ void Cleanup(TF_RandomAccessFile* file) {
static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "ReadFile using S3Client\n");
Aws::S3::Model::GetObjectRequest get_object_request;
get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
Aws::String bytes =
@ -306,12 +312,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "Using TransferManager\n");
auto create_download_stream = [&]() {
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
};
TF_VLog(3, "Created stream to read with transferManager\n");
auto handle = s3_file->transfer_manager->DownloadFile(
s3_file->bucket, s3_file->object, offset, n, create_download_stream);
handle->WaitUntilFinished();
@ -322,6 +330,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// Only failed parts will be downloaded again.
TF_VLog(
1,
"Retrying read of s3://%s/%s after failure. Current retry count: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryDownload(handle);
handle->WaitUntilFinished();
}
@ -341,6 +353,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n);
if (s3_file->use_multi_part_download)
return ReadS3TransferManager(s3_file, offset, n, buffer, status);
else
@ -416,6 +430,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
return;
}
TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(),
s3_file->object.c_str());
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
auto handle = s3_file->transfer_manager->UploadFile(
s3_file->outfile, s3_file->bucket, s3_file->object,
@ -426,6 +442,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
retries++ < kUploadRetries) {
// if multipart upload was used, only the failed parts will be re-sent
TF_VLog(1,
"Retrying upload of s3://%s/%s after failure. Current retry count: "
"%u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
handle->WaitUntilFinished();
}
@ -613,6 +633,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
TF_VLog(1, "Stat on path: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -737,6 +758,8 @@ static void SimpleCopyFile(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CopyObjectRequest copy_object_request;
copy_object_request.WithCopySource(source)
.WithBucket(bucket_dst)
@ -801,6 +824,8 @@ static void MultiPartCopy(const Aws::String& source,
const Aws::String& object_dst, const size_t num_parts,
const uint64_t file_size, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);
@ -827,6 +852,8 @@ static void MultiPartCopy(const Aws::String& source,
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(),
num_parts, chunk_size);
size_t retries = 0;
while (retries++ < 3) {
// Queue up parts.
@ -891,6 +918,9 @@ static void MultiPartCopy(const Aws::String& source,
status);
} else {
// Retry.
TF_Log(TF_ERROR,
"Retrying failed copy of part %u due to an error with S3\n",
part_number);
num_finished_parts--;
}
}
@ -967,6 +997,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteFile: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -985,6 +1016,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "CreateDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1026,6 +1058,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1060,6 +1093,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
TF_VLog(1, "RenameFile from: %s to %s\n", src, dst);
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1120,6 +1154,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
TF_VLog(1, "GetChildren for path: %s\n", path);
Aws::String bucket, prefix;
ParseS3Path(path, true, &bucket, &prefix, status);
if (TF_GetCode(status) != TF_OK) return -1;

View File

@ -3,6 +3,24 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_grad",
srcs = ["array_grad.cc"],
hdrs = [
"array_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
@ -19,6 +37,28 @@ cc_library(
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "nn_grad",
srcs = ["nn_grad.cc"],
hdrs = [
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(grad_inputs.size(), nullptr);
for (int i = 0; i < grad_inputs.size(); i++) {
auto grad_input = grad_inputs[i];
// TODO(srbs): Should we add a copy contructor to AbstractTensorHandle
// that takes care of this similar to `Tensor`?
if (grad_input) {
grad_input->Ref();
}
(*grad_outputs)[i] = grad_input;
}
return Status::OK();
}
~IdentityNGradientFunction() override {}
};
} // namespace
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
auto gradient_function = new IdentityNGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

View File

@ -15,13 +15,17 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike;
namespace tensorflow {
namespace gradients {
@ -29,20 +33,23 @@ namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
std::string name = "Identity_A";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity0"));
name.c_str()));
(*grad_outputs)[0] = identity_outputs[0];
name = "Identity_B";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity1"));
name.c_str()));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
@ -54,16 +61,18 @@ class ExpGradientFunction : public GradientFunction {
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
std::string name = "Conj_Exp_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
name = "Mul_Exp_Grad";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), "ExpGradMul"));
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~ExpGradientFunction() override {}
@ -72,14 +81,142 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_;
};
class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
AttrBuilder f_attrs)
: forward_inputs(f_inputs), forward_attrs(f_attrs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a matmul op A*B, the gradients are:
*
* dA = U * B.T
* dB = A.T * U
*
* where A.T means `transpose(A)`
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
grad_outputs->resize(2);
// Get transpose attrs
bool t_a;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a));
bool t_b;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b));
// Conj each input
vector<AbstractTensorHandle*> conj_outputs(1);
std::string name = "Conj_A_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* A = conj_outputs[0];
name = "Conj_B_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* B = conj_outputs[0];
// Calc Grad
vector<AbstractTensorHandle*> matmul_A_outputs(1);
vector<AbstractTensorHandle*> matmul_B_outputs(1);
std::string name_grad_A = "MatMul_Grad_A";
std::string name_grad_B = "MatMul_Grad_B";
if (!t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (!t_a && t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
} else { // t_a && t_b
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ true));
}
// Gradient for A
(*grad_outputs)[0] = matmul_A_outputs[0];
// Gradient for B
(*grad_outputs)[1] = matmul_B_outputs[0];
return Status::OK();
}
~MatMulGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_inputs;
AttrBuilder forward_attrs;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
auto gradient_function = new AddGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
return new ExpGradientFunction(op.outputs[0]);
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients

View File

@ -19,9 +19,10 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
using tensorflow::ops::ZerosLike;
namespace tensorflow {
namespace gradients {
namespace {
class ReluGradientFunction : public GradientFunction {
public:
explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* activations = forward_outputs[0];
grad_outputs->resize(1);
vector<AbstractTensorHandle*> relugrad_outputs(1);
// Calculate Grad
std::string name = "relu_grad";
TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
absl::MakeSpan(relugrad_outputs),
name.c_str()));
(*grad_outputs)[0] = relugrad_outputs[0];
return Status::OK();
}
~ReluGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_outputs;
};
class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
public:
explicit SparseSoftmaxCrossEntropyLossGradientFunction(
vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
// Grad for Softmax Input
std::string name = "Mul_Softmax_Grad";
vector<AbstractTensorHandle*> mul_outputs(1);
TF_RETURN_IF_ERROR(
ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]},
absl::MakeSpan(mul_outputs),
name.c_str())); // upstream_grad * local softmax grad
(*grad_outputs)[0] = mul_outputs[0];
// Grad for labels is null
(*grad_outputs)[1] = nullptr;
return Status::OK();
}
~SparseSoftmaxCrossEntropyLossGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_outputs;
};
} // namespace
BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
auto gradient_function = new ReluGradientFunction(op.outputs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
const ForwardOperation& op) {
auto gradient_function =
new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_

View File

@ -38,10 +38,29 @@ cc_library(
deps = [
":array_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "nn_ops",
srcs = [
"nn_ops.cc",
],
hdrs = [
"nn_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
@ -35,5 +35,19 @@ Status Identity(AbstractContext* ctx,
return identity_op->Execute(outputs, &num_retvals);
}
Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr z_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
int num_retvals = 1;
return z_op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -22,9 +22,15 @@ limitations under the License.
namespace tensorflow {
namespace ops {
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -51,5 +51,60 @@ Status Conj(AbstractContext* ctx,
return Status::OK();
}
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr add_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a = false, bool transpose_b = false) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));
int num_retvals = 1;
TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr neg_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(neg_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
int num_retvals = 1;
return neg_op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -25,6 +25,15 @@ Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b);
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Softmax Loss given scores and labels, used by the SoftMaxLossGradient
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels
// Outputs will contain: [loss_vals, gradients].
int num_retvals = 2;
TF_RETURN_IF_ERROR(sm_loss_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// Computes Relu gradient given input features
Status ReluGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr relugrad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(relugrad_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs
int num_retvals = 1;
TF_RETURN_IF_ERROR(relugrad_op->Execute(outputs, &num_retvals));
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace ops {
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ReluGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_

View File

@ -44,7 +44,9 @@ cc_library(
],
deps = [
":concrete_function",
":signature_def_function",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
@ -70,6 +72,26 @@ cc_library(
],
)
cc_library(
name = "signature_def_function",
hdrs = [
"signature_def_function.h",
],
deps = [
":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "signature_def_function_metadata",
hdrs = [
"signature_def_function_metadata.h",
],
)
cc_library(
name = "test_utils",
testonly = True,
@ -115,6 +137,7 @@ cc_library(
":concrete_function",
":saved_model_api",
":saved_model_utils",
":signature_def_function",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
@ -206,13 +229,13 @@ tf_cc_test(
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)

View File

@ -26,10 +26,14 @@ limitations under the License.
namespace tensorflow {
// Note that ConcreteFunctions's lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// ConcreteFunctions correspond to an instance of a tf.function with a known set
// of inputs (either through get_concrete_function) or an input_signature.
// ConcreteFunction attempts to preserve the user-facing semantics of the
// tf.function python API and can take a limited set of types as arguments
// (to be modeled in tensorflow::Value), not just Tensors.
// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they
// are loaded from, since they retain pointers to the TensorHandles owned by the
// SavedModel, and the FunctionDef of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection

View File

@ -37,10 +37,11 @@ static const char kNoSharingResourceID[] =
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
const char* raw_device_name,
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and

View File

@ -31,6 +31,7 @@ namespace internal {
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
const char* raw_device_name,
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated

View File

@ -55,7 +55,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
context(), DT_FLOAT, {}, nullptr, &handle));
// The created TensorHandle should be a DT_Resource
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}
@ -65,7 +65,7 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
context(), DT_FLOAT, {}, nullptr, &handle));
// Destroy the variable
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
@ -76,7 +76,7 @@ TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
context(), DT_FLOAT, {}, nullptr, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.

View File

@ -65,10 +65,11 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, &handle));
ctx, dtype, shape, raw_device_name, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));

View File

@ -37,6 +37,7 @@ class Variable : public TensorHandleConvertible {
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@ -39,11 +40,11 @@ class SavedModelAPI {
virtual Status GetFunction(const std::string& function_path,
ConcreteFunction** function) = 0;
// Retrieve a function from a SavedModel, using the key of the
// Retrieve a SignatureDefFunction from a SavedModel, using the key of the
// SignatureDef map:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) = 0;
SignatureDefFunction** function) = 0;
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;

View File

@ -122,9 +122,9 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
TF_RETURN_IF_ERROR(
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
ctx, dtype, shape, name,
variable.device().empty() ? nullptr : variable.device().c_str(), output));
return Status();
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -38,9 +39,15 @@ namespace {
class SavedVariableLoadingTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>>> {
public:
SavedVariableLoadingTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
SavedVariableLoadingTest() {
SessionOptions options;
options.config.mutable_device_count()->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
ctx_ = testing::CreateTestingEagerContext(device_mgr_.get());
}
EagerContext* context() { return ctx_.get(); }
@ -67,6 +74,39 @@ TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
EXPECT_EQ(var->shape(), shape);
}
// Verify that a device specified in the SavedVariable is kept.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithDevice) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
SavedVariable saved_variable;
saved_variable.set_dtype(dtype);
saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:1"),
shape.AsProto(saved_variable.mutable_shape());
std::unique_ptr<Variable> var;
TF_ASSERT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
EXPECT_EQ(down_cast<TensorHandle*>(var->handle())->resource_device()->name(),
"/job:localhost/replica:0/task:0/device:CPU:1");
}
// Verify load failure if a non-existing device is specified.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithInvalidDevice) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
SavedVariable saved_variable;
saved_variable.set_dtype(dtype);
saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:99"),
shape.AsProto(saved_variable.mutable_shape());
std::unique_ptr<Variable> var;
ASSERT_NE(Status::OK(),
internal::LoadSavedVariable(context(), saved_variable, &var));
}
// Assigning and reading values should yield
// consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
@ -79,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, &var));
absl::nullopt, nullptr, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
namespace tensorflow {
// See tensorflow/cc/experimental/saved_model/public/signature_def_function.h
// for SignatureDefFunction's intended user-facing semantics.
// This class is the "implementation" C++ part of the C++/C/C++ sandwich for
// a SignatureDefFunction.
// Note(bmzhao): Implementation-wise, SignatureDefFunctions are always saved as
// a "BareConcreteFunction", w/o a FunctionSpec, rather than a SavedFunction:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/saved_object_graph.proto#L60
// Additionally they are guaranteed to be children of the .signatures attribute
// of the root object, where the child object "name" is the signature_def key:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/python/saved_model/signature_serialization.py#L181-L230
// One of the critical requirements of SignatureDef functions is that their
// inputs and outputs are "named". For example, a `.signatures` function:
// a. Requires users to pass: kwargs of all inputs:
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L119-L126
// b. Returns a dictionary of named outputs.
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L153-L161
// Since SignatureDefFunctions do not have FunctionSpecs, but guarantee the
// dictionary of inputs/outputs, we can parse these dictionaries' keys to obtain
// the input/output names of the SignatureDef:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/meta_graph.proto#L318-L321
class SignatureDefFunction {
public:
virtual ~SignatureDefFunction() = default;
// Creates a "Call" Op used to execute the function.
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const = 0;
virtual const SignatureDefFunctionMetadata& GetFunctionMetadata() const = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
namespace tensorflow {
class SignatureDefFunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -305,7 +306,7 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
}
Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, ConcreteFunction** function) {
const std::string& signature_def_key, SignatureDefFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
"Retrieving SignatureDef functions is unimplemented currently");

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h"
@ -55,7 +56,7 @@ class TFSavedModelAPI : public SavedModelAPI {
ConcreteFunction** function) override;
Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) override;
SignatureDefFunction** function) override;
static Status Load(
const std::string& directory,

View File

@ -142,6 +142,8 @@ cc_library(
":concrete_function_list_type",
":concrete_function_type",
":saved_model_api_type",
":signature_def_function",
":signature_def_function_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
@ -165,6 +167,77 @@ cc_library(
],
)
cc_library(
name = "signature_def_function",
srcs = [
"signature_def_function.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_function.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_function_metadata",
":signature_def_function_metadata_type",
":signature_def_function_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "signature_def_function_type",
hdrs = [
"signature_def_function_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
],
)
cc_library(
name = "signature_def_function_metadata",
srcs = [
"signature_def_function_metadata.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "signature_def_function_metadata_type",
hdrs = [
"signature_def_function_metadata_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -106,9 +107,11 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
return tensorflow::wrap(result);
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
TF_CAPI_EXPORT extern TF_SignatureDefFunction*
TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
const char* signature_def_key,
TF_Status* status) {
tensorflow::SignatureDefFunction* result = nullptr;
tensorflow::Status get_function_status =
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
TF_SignatureDefFunctionMetadata* TF_SignatureDefFunctionGetMetadata(
TF_SignatureDefFunction* func) {
return tensorflow::wrap(const_cast<tensorflow::SignatureDefFunctionMetadata*>(
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TFE_Op* TF_SignatureDefFunctionMakeCallOp(TF_SignatureDefFunction* func,
TFE_TensorHandle** inputs,
int num_inputs, TF_Status* status) {
tensorflow::ImmediateOpPtr call_op;
absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs));
status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
if (!status->status.ok()) {
return nullptr;
}
return tensorflow::wrap(call_op.release());
}
} // end extern "C"

View File

@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
// Static initialization for TensorFlow.js op registration.
static mlir::DialectRegistration<mlir::tfjs::TFJSDialect> tfjs_ops;
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
// TODO(bmzhao): Add getter functions here as necessary.

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunctionMetadata,
TF_SignatureDefFunctionMetadata)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
typedef struct TF_SignatureDefFunction TF_SignatureDefFunction;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunction,
TF_SignatureDefFunction)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_

View File

@ -24,6 +24,8 @@ exports_files(
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
"signature_def_function.h",
"signature_def_function_metadata.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -39,6 +41,8 @@ cc_library(
":concrete_function_list",
":function_metadata",
":saved_model_api",
":signature_def_function",
":signature_def_function_metadata",
],
)
@ -61,3 +65,13 @@ alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
alias(
name = "signature_def_function",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function",
)
alias(
name = "signature_def_function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
)

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -40,6 +40,13 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
// The caller is responsible for deleting the returned TFE_Op. If op
// construction fails, `status` will be non-OK and the returned pointer will be
// null.
// TODO(bmzhao): Remove this function in a subsequent change; Design + implement
// a Function Execution interface for ConcreteFunction that accepts a tagged
// union of types (tensorflow::Value). This effectively requires moving much of
// the implementation of function.py/def_function.py to C++, and exposing a
// high-level API here. A strawman for what this interface could look like:
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
// inputs, int num_inputs, TF_Status* status);
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
@ -91,10 +92,13 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
// TF_SignatureDefFunction instance. Once `model` is deleted, all
// `TF_SignatureDefFunctions` retrieved from it are invalid, and have been
// deleted.
TF_CAPI_EXPORT extern TF_SignatureDefFunction*
TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
const char* signature_def_key,
TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that corresponds to a SignatureDefFunction loaded from a
// SavedModel.
typedef struct TF_SignatureDefFunction TF_SignatureDefFunction;
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_SignatureDefFunctionMetadata*
TF_SignatureDefFunctionGetMetadata(TF_SignatureDefFunction* func);
// Returns a TFE_Op suitable for executing this function. Caller must provide
// all function inputs in `inputs`, and must not add any additional inputs on
// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList).
// The caller is responsible for deleting the returned TFE_Op. If op
// construction fails, `status` will be non-OK and the returned pointer will be
// null.
TF_CAPI_EXPORT extern TFE_Op* TF_SignatureDefFunctionMakeCallOp(
TF_SignatureDefFunction* func, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that corresponds to a SignatureDefFunction loaded from a
// SavedModel.
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_

View File

@ -0,0 +1,60 @@
# Description:
# StreamExecutor C API.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "stream_executor",
srcs = ["stream_executor.cc"],
hdrs = ["stream_executor.h"],
visibility = ["//visibility:public"],
deps = [
":stream_executor_internal",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor:stream_executor_internal",
"//tensorflow/stream_executor:stream_executor_pimpl",
"//tensorflow/stream_executor:timer",
],
)
cc_library(
name = "stream_executor_internal",
hdrs = [
"stream_executor.h",
"stream_executor_internal.h",
],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor/lib",
],
)
tf_cc_test(
name = "stream_executor_test",
srcs = ["stream_executor_test.cc"],
deps = [
":stream_executor",
":stream_executor_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor:stream_executor_pimpl",
],
)

View File

@ -0,0 +1,809 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "SP_Something".
//
// This file also contains stream_executor::Platform registration for pluggable
// device.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include <string>
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
using tensorflow::StatusFromTF_Status;
namespace stream_executor {
namespace {
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
do { \
if (STRUCT_OBJ.struct_size == 0) { \
return port::FailedPreconditionError( \
"struct_size field in " #STRUCT_NAME \
" must be set to " #SIZE_VALUE_NAME "."); \
} \
} while (0)
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \
do { \
if (STRUCT_OBJ.NAME == 0) { \
return port::FailedPreconditionError( \
"'" #NAME "' field in " #STRUCT_NAME " must be set."); \
} \
} while (0)
port::Status ValidateSPPlatform(const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
VALIDATE_MEMBER(SP_Platform, platform, name);
VALIDATE_MEMBER(SP_Platform, platform, type);
VALIDATE_MEMBER(SP_Platform, platform, visible_device_count);
VALIDATE_MEMBER(SP_Platform, platform, create_device);
VALIDATE_MEMBER(SP_Platform, platform, destroy_device);
VALIDATE_MEMBER(SP_Platform, platform, create_stream_executor);
VALIDATE_MEMBER(SP_Platform, platform, destroy_stream_executor);
VALIDATE_MEMBER(SP_Platform, platform, create_timer_fns);
VALIDATE_MEMBER(SP_Platform, platform, destroy_timer_fns);
return port::Status::OK();
}
port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds);
return port::Status::OK();
}
port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPDevice(const SP_Device& device) {
VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se) {
VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status);
VALIDATE_MEMBER(SP_StreamExecutor, se, record_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh);
VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod);
VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh);
VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod);
VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity);
VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback);
return port::Status::OK();
}
port::Status ValidateSEPlatformRegistrationParams(
const SE_PlatformRegistrationParams& params) {
VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
return port::Status::OK();
}
#undef VALIDATE_MEMBER
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
class CStream : public internal::StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};
// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}
class CEvent : public internal::EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status Record(SP_Stream stream_handle) {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Event event_handle_;
};
class CTimer : public internal::TimerInterface {
public:
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
SP_Timer Handle() { return timer_handle_; }
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Timer timer_handle_;
SP_TimerFns* timer_fns_;
};
// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
// `opaque` field inside SP_DeviceMemoryBase is not const.
// Therefore, we need to cast away the constness before setting it.
device_memory_base.opaque = const_cast<void*>(mem->opaque());
device_memory_base.size = mem->size();
device_memory_base.payload = mem->payload();
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
return device_memory_base;
}
DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
DeviceMemoryBase base(mem.opaque, mem.size);
base.SetPayload(mem.payload);
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
return base;
}
// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
std::function<port::Status()> callback;
};
// This wrapper allows calling `HostCallbackContext::callback` across C API.
// This function matches `SE_StatusCallbackFn` signature and will be passed as
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
port::Status s = host_ctx->callback();
Set_TF_Status_from_Status(status, s);
delete host_ctx;
}
class CStreamExecutor : public internal::StreamExecutorInterface {
public:
explicit CStreamExecutor(SP_Device device,
void (*destroy_device)(SP_Device* const device),
SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns, const std::string& name,
int visible_device_count)
: device_(std::move(device)),
destroy_device_(destroy_device),
stream_executor_(stream_executor),
timer_fns_(timer_fns),
platform_name_(name),
visible_device_count_(visible_device_count) {}
~CStreamExecutor() override { destroy_device_(&device_); }
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
return port::Status::OK();
}
DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override {
SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
stream_executor_->allocate(&device_, size, memory_space, &mem);
port::Status status = ValidateSPDeviceMemoryBase(mem);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
return DeviceMemoryBaseFromC(mem);
}
DeviceMemoryBase Allocate(uint64 size) {
return Allocate(size, /*memory_space=*/0);
}
void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
uint64 size) override {
LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
}
void Deallocate(DeviceMemoryBase* mem) override {
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
stream_executor_->deallocate(&device_, &device_memory_base);
}
void* HostMemoryAllocate(uint64 size) override {
return stream_executor_->host_memory_allocate(&device_, size);
}
void HostMemoryDeallocate(void* mem) override {
stream_executor_->host_memory_deallocate(&device_, mem);
}
bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
bool HostMemoryUnregister(void* mem) override { return false; }
absl::optional<AllocatorStats> GetAllocatorStats() override {
SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
TF_Bool has_stats =
stream_executor_->get_allocator_stats(&device_, &c_stats);
if (!has_stats) {
return absl::nullopt;
}
port::Status status = ValidateSPAllocatorStats(c_stats);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
return absl::nullopt;
}
// TODO(annarev): validate SP_AllocatorStats.
::stream_executor::AllocatorStats stats;
stats.num_allocs = c_stats.num_allocs;
stats.bytes_in_use = c_stats.bytes_in_use;
stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
stats.largest_alloc_size = c_stats.largest_alloc_size;
if (c_stats.has_bytes_limit) {
stats.bytes_limit = c_stats.bytes_limit;
}
stats.bytes_reserved = c_stats.bytes_reserved;
stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
if (c_stats.has_bytes_reservable_limit) {
stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
}
stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
return stats;
}
bool SynchronizeAllActivity() override {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->synchronize_all_activity(&device_, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
port::Status SynchronousMemZero(DeviceMemoryBase* location,
uint64 size) override {
// TODO(annarev): figure out if we should support memzero/memset
// functionality by allocating on host and then copying to device.
return port::UnimplementedError(
"SynchronousMemZero is not supported by pluggable device.");
}
port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
uint64 size) override {
return port::UnimplementedError(
"SynchronousMemSet is not supported by pluggable device.");
}
port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
const void* host_src, uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status SynchronousMemcpy(void* host_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
&device_mem_src, size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
uint64 size) override {
return port::UnimplementedError(
"MemZero is not supported by pluggable device.");
}
port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
uint64 size) override {
return port::UnimplementedError(
"Memset is not supported by pluggable device.");
}
port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
uint32 pattern, uint64 size) override {
return port::UnimplementedError(
"Memset32 is not supported by pluggable device.");
}
bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
&device_mem_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
host_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
&device_mem_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool HostCallback(Stream* stream,
std::function<port::Status()> callback) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
HostCallbackContext* ctx = new HostCallbackContext{callback};
return stream_executor_->host_callback(&device_, stream_handle,
&HostCallbackTrampoline, ctx);
}
port::Status AllocateEvent(Event* event) override {
DCHECK(event != nullptr);
return static_cast<CEvent*>(event->implementation())->Create();
}
port::Status DeallocateEvent(Event* event) override {
static_cast<CEvent*>(event->implementation())->Destroy();
return port::Status::OK();
}
port::Status RecordEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
}
port::Status WaitForEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
SE_EventStatus event_status =
stream_executor_->get_event_status(&device_, event_handle);
return SEEventStatusToEventStatus(event_status);
}
bool AllocateStream(Stream* stream) override {
DCHECK(stream != nullptr);
port::Status status =
static_cast<CStream*>(stream->implementation())->Create();
// TODO(annarev): update AllocateStream to return status instead
// (similar to AllocateEvent).
return status.ok();
}
void DeallocateStream(Stream* stream) override {
static_cast<CStream*>(stream->implementation())->Destroy();
}
bool CreateStreamDependency(Stream* dependent, Stream* other) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream dependent_handle =
static_cast<CStream*>(dependent->implementation())->Handle();
SP_Stream other_handle =
static_cast<CStream*>(other->implementation())->Handle();
stream_executor_->create_stream_dependency(&device_, dependent_handle,
other_handle, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool AllocateTimer(Timer* timer) override {
port::Status status =
static_cast<CTimer*>(timer->implementation())->Create();
// TODO(annarev): change return value of AllocateTimer
// to status (similar to AllocateEvent).
return status.ok();
}
void DeallocateTimer(Timer* timer) override {
static_cast<CTimer*>(timer->implementation())->Destroy();
}
bool StartTimer(Stream* stream, Timer* timer) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Timer timer_handle =
static_cast<CTimer*>(timer->implementation())->Handle();
stream_executor_->start_timer(&device_, stream_handle, timer_handle,
c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool StopTimer(Stream* stream, Timer* timer) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Timer timer_handle =
static_cast<CTimer*>(timer->implementation())->Handle();
stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
port::Status BlockHostForEvent(Stream* stream, Event* event) {
OwnedTFStatus c_status(TF_NewStatus());
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
stream_executor_->block_host_for_event(&device_, event_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status BlockHostUntilDone(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Event event_handle;
stream_executor_->create_event(&device_, &event_handle, c_status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
stream_executor_->record_event(&device_, stream_handle, event_handle,
c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
if (!s.ok()) {
stream_executor_->destroy_event(&device_, event_handle);
return s;
}
stream_executor_->block_host_for_event(&device_, event_handle,
c_status.get());
stream_executor_->destroy_event(&device_, event_handle);
return StatusFromTF_Status(c_status.get());
}
port::Status GetStatus(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
stream_executor_->get_stream_status(&device_, stream_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
int PlatformDeviceCount() override { return visible_device_count_; }
port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
return port::UnimplementedError(
"EnablePeerAccessTo is not supported by pluggable device.");
}
bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
return false;
}
bool DeviceMemoryUsage(int64* free, int64* total) const override {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
return stream_executor_->device_memory_usage(
&device_, reinterpret_cast<int64_t*>(free),
reinterpret_cast<int64_t*>(total));
}
// Creates a new DeviceDescription object.
// Ownership is transferred to the caller.
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
// TODO(annarev): Figure out if we need to support more description fields.
internal::DeviceDescriptionBuilder builder;
builder.set_name(platform_name_);
return builder.Build();
}
// Each call creates a new instance of the platform-specific implementation of
// the corresponding interface type.
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
return std::unique_ptr<internal::EventInterface>(
new CEvent(&device_, stream_executor_));
}
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
override {
LOG(FATAL)
<< "CreateKernelImplementation is not supported by pluggable device.";
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(
new CStream(&device_, stream_executor_));
}
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
return std::unique_ptr<internal::TimerInterface>(
new CTimer(&device_, stream_executor_, timer_fns_));
}
private:
SP_Device device_;
void (*destroy_device_)(SP_Device* const device);
SP_StreamExecutor* stream_executor_;
SP_TimerFns* timer_fns_;
std::string platform_name_;
int visible_device_count_;
};
} // namespace
CPlatform::CPlatform(SP_Platform platform,
void (*destroy_platform)(SP_Platform*),
SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
: platform_(std::move(platform)),
destroy_platform_(destroy_platform),
stream_executor_(std::move(stream_executor)),
timer_fns_(std::move(timer_fns)),
name_(platform.name) {}
CPlatform::~CPlatform() {
executor_cache_.DestroyAllExecutors();
platform_.destroy_stream_executor(&stream_executor_);
platform_.destroy_timer_fns(&timer_fns_);
destroy_platform_(&platform_);
}
port::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
// TODO(annarev): see if we can get StreamExecutor instance
// and call GetDeviceDescription. executor_cache_.Get would need
// to be made const for it to work.
internal::DeviceDescriptionBuilder builder;
builder.set_name(name_);
return builder.Build();
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
stream_executor::StreamExecutorConfig config;
config.ordinal = ordinal;
return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
int ordinal, const PluginConfig& plugin_config) {
StreamExecutorConfig config;
config.ordinal = ordinal;
config.plugin_config = plugin_config;
return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
// Fill device creation params
SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
SP_Device device{SP_DEVICE_STRUCT_SIZE};
device_params.device = &device;
device_params.ext = nullptr;
device_params.ordinal = config.ordinal;
OwnedTFStatus c_status(TF_NewStatus());
// Create Device
platform_.create_device(&device_params, c_status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPDevice(device));
auto executor = absl::make_unique<CStreamExecutor>(
std::move(device), platform_.destroy_device, &stream_executor_,
&timer_fns_, name_, platform_.visible_device_count);
auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
config.ordinal);
return result;
}
port::Status RegisterDevicePlugin(const std::string& dso_path) {
// Step 1: Load plugin
tensorflow::Env* env = tensorflow::Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
auto init_fn = reinterpret_cast<SEPluginInitFn>(dso_symbol);
return RegisterDevicePlugin(init_fn);
}
port::Status RegisterDevicePlugin(SEPluginInitFn init_fn) {
SE_PlatformRegistrationParams params{
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
params.major_version = SE_MAJOR;
params.minor_version = SE_MINOR;
params.revision_version = SE_REVISION;
params.platform = &platform;
OwnedTFStatus c_status(TF_NewStatus());
init_fn(&params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
// Fill stream executor creation params
SE_CreateStreamExecutorParams se_params{
SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
se_params.stream_executor = &se;
// Create StreamExecutor
platform.create_stream_executor(&se_params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se));
SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
platform.create_timer_fns(&timer_fns, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
// Register new platform
std::string platform_name = std::string(platform.name);
std::unique_ptr<stream_executor::CPlatform> cplatform(
new stream_executor::CPlatform(std::move(platform),
params.destroy_platform, std::move(se),
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));
// TODO(annarev): Add pluggable device registration here.
return port::Status::OK();
}
} // namespace stream_executor

View File

@ -0,0 +1,395 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_status.h"
// --------------------------------------------------------------------------
// C API for StreamExecutor. The API is under active development and eventually
// should allow registering a pluggable device with TensorFlow.
//
// Conventions:
// * Struct prefix indicates whether struct fields should be filled by the
// plugin or core implementation:
// * SE_ : set/filled by core unless explicitly marked otherwise.
// * SP_ : set/filled by plugin unless explicitly marked otherwise.
// * We use `struct_size` for version checking. It is exempt from the `SE/SP`
// rule above and should be set both by core and the plugin.
// * For example, `create_device` function receives `SP_Device*` as input
// with `struct_size` populated by core. The plugin is responsible for
// setting `struct_size` as well, along with all other fields.
// * Refer to "TensorFlow Versioning Strategy" section at
// https://github.com/tensorflow/community/pull/257/files.
// * Note that the API is still under active development and doesn't have
// versioning guarantees yet.
// * `void* ext` is a free-form field that can be populated by
// a plugin in `SP_*` structs or potential future extension points in `SE_`
// structs.
//
// Example usage:
//
// /* Sample TensorFlow code below, exact implementation might differ. */
// // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule
// // above and should be set both by core and the plugin."
// SP_Device device { SP_DEVICE_STRUCT_SIZE };
// SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ;
// params.device = &device;
//
// /* Plugin code below */
// constexpr char DEVICE_NAME[] = "MyDevice";
// constexpr char DEVICE_TYPE[] = "GPU";
//
// void create_device(SE_CreateDeviceParams* params, TF_Status* status) {
// // Custom actions based on TensorFlow's view of SP_Device.
// OnTFDeviceView(params->device->struct_size);
// params->device = { SP_DEVICE_STRUCT_SIZE };
// params->device->device_handle = get_my_device_handle(device->ordinal);
// params->device->ordinal = params->ordinal;
// ...
// }
//
// void destroy_device(SP_Device* device) {
// delete_my_device_handle(device->device_handle);
// }
//
// void SE_InitPlugin(
// SE_PlatformRegistrationParams* params,
// TF_Status* status) {
// params->platform = { SP_PLATFORM_STRUCT_SIZE };
// // Values such as `name` and `type` must outlive SE_InitPlugin call.
// params->platform->name = DEVICE_NAME;
// params->platform->type = DEVICE_TYPE;
// params->platform->visible_device_count = 2;
// params->platform->create_device = create_device;
// params->platform->destroy_device = destroy_device;
// ...
// }
#define SE_MAJOR 0
#define SE_MINOR 0
#define SE_REVISION 1
#ifdef __cplusplus
extern "C" {
#endif
typedef struct SP_Stream_st* SP_Stream;
typedef struct SP_Event_st* SP_Event;
typedef struct SP_Timer_st* SP_Timer;
// Takes `callback_arg` passed to `host_callback` as the first argument.
typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const);
typedef struct SP_TimerFns {
size_t struct_size;
void* ext; // reserved for future use
uint64_t (*nanoseconds)(SP_Timer timer);
} SP_TimerFns;
#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds)
typedef struct SP_AllocatorStats {
size_t struct_size;
int64_t num_allocs;
int64_t bytes_in_use;
int64_t peak_bytes_in_use;
int64_t largest_alloc_size;
int8_t has_bytes_limit;
int64_t bytes_limit;
int64_t bytes_reserved;
int64_t peak_bytes_reserved;
int8_t has_bytes_reservable_limit;
int64_t bytes_reservable_limit;
int64_t largest_free_block_bytes;
} SP_AllocatorStats;
#define SP_ALLOCATORSTATS_STRUCT_SIZE \
TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes)
// Potential states for an SP_Event. If `poll_for_status` returns anything aside
// from kPending or kComplete, an error has occurred; kUnknown is a bad state.
typedef enum SE_EventStatus {
SE_EVENT_UNKNOWN,
SE_EVENT_ERROR,
SE_EVENT_PENDING,
SE_EVENT_COMPLETE,
} SE_EventStatus;
// Memory allocation information.
// This matches DeviceMemoryBase defined here:
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase {
size_t struct_size;
void* ext; // free-form data set by plugin
// Platform-dependent value representing allocated memory.
void* opaque;
uint64_t size; // Size in bytes of this allocation.
uint64_t payload; // Value for plugin's use
} SP_DeviceMemoryBase;
#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \
TF_OFFSET_OF_END(SP_DeviceMemoryBase, size)
typedef struct SP_Device {
size_t struct_size;
void* ext; // free-form data set by plugin
int32_t ordinal; // device index
// Device vendor can store handle to their device representation
// here.
void* device_handle;
} SP_Device;
#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle)
typedef struct SE_CreateDeviceParams {
size_t struct_size;
void* ext; // reserved for future use
int32_t ordinal; // device index
SP_Device* device; // Input/output, struct_size set by TF for plugin to read.
// Subsequently plugin fills the entire struct.
} SE_CreateDeviceParams;
#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(SE_CreateDeviceParams, device)
typedef struct SP_StreamExecutor {
size_t struct_size;
void* ext; // reserved for future use
/*** ALLOCATION CALLBACKS ***/
// Synchronously allocates `size` bytes on the underlying platform and returns
// `SP_DeviceMemoryBase` representing that allocation. In the case of failure,
// nullptr is returned.
// `memory_space` is reserved for a potential future usage and should be set
// to 0.
void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space,
SP_DeviceMemoryBase* mem);
// Deallocate the device memory previously allocated via this interface.
// Deallocation of a nullptr-representative value is permitted.
void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory);
// Allocates a region of host memory and registers it with the platform API.
// Memory allocated in this manner is required for use in asynchronous memcpy
// operations, such as `memcpy_dtoh`.
void* (*host_memory_allocate)(const SP_Device* device, uint64_t size);
// Deallocates a region of host memory allocated by `host_memory_allocate`.
void (*host_memory_deallocate)(const SP_Device* device, void* mem);
// Fills SP_AllocatorStats with allocator statistics, if it is available.
// If it is not available, return false.
TF_Bool (*get_allocator_stats)(const SP_Device* device,
SP_AllocatorStats* stats);
// Fills the underlying device memory usage information, if it is
// available. If it is not available (false is returned), free/total need not
// be initialized.
TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free,
int64_t* total);
/*** STREAM CALLBACKS ***/
// Creates SP_Stream. This call should also allocate stream
// resources on the underlying platform and initializes its
// internals.
void (*create_stream)(const SP_Device* device, SP_Stream* stream,
TF_Status* status);
// Destroys SP_Stream and deallocates any underlying resources.
void (*destroy_stream)(const SP_Device* device, SP_Stream stream);
// Causes `dependent` to not begin execution until `other` has finished its
// last-enqueued work.
void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent,
SP_Stream other, TF_Status* status);
// Without blocking the device, retrieve the current stream status.
void (*get_stream_status)(const SP_Device* device, SP_Stream stream,
TF_Status* status);
/*** EVENT CALLBACKS ***/
// Create SP_Event. Performs platform-specific allocation and initialization
// of an event.
void (*create_event)(const SP_Device* device, SP_Event* event,
TF_Status* status);
// Destroy SE_Event and perform any platform-specific deallocation and
// cleanup of an event.
void (*destroy_event)(const SP_Device* device, SP_Event event);
// Requests the current status of the event from the underlying platform.
SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event);
// Inserts the specified event at the end of the specified stream.
void (*record_event)(const SP_Device* device, SP_Stream stream,
SP_Event event, TF_Status* status);
// Wait for the specified event at the end of the specified stream.
void (*wait_for_event)(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status);
/*** TIMER CALLBACKS ***/
// Creates SP_Timer. Allocates timer resources on the underlying platform
// and initializes its internals, setting `timer` output variable. Sets
// values in `timer_fns` struct.
void (*create_timer)(const SP_Device* device, SP_Timer* timer,
TF_Status* status);
// Destroy timer and deallocates timer resources on the underlying platform.
void (*destroy_timer)(const SP_Device* device, SP_Timer timer);
// Records a start event for an interval timer.
void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
TF_Status* status);
// Records a stop event for an interval timer.
void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
TF_Status* status);
/*** MEMCPY CALLBACKS ***/
// Enqueues a memcpy operation onto stream, with a host destination location
// `host_dst` and a device memory source, with target size `size`.
void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst,
const SP_DeviceMemoryBase* device_src, uint64_t size,
TF_Status* status);
// Enqueues a memcpy operation onto stream, with a device destination
// location and a host memory source, with target size `size`.
void (*memcpy_htod)(const SP_Device* device, SP_Stream stream,
SP_DeviceMemoryBase* device_dst, const void* host_src,
uint64_t size, TF_Status* status);
// Enqueues a memcpy operation onto stream, with a device destination
// location and a device memory source, with target size `size`.
void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream,
SP_DeviceMemoryBase* device_dst,
const SP_DeviceMemoryBase* device_src, uint64_t size,
TF_Status* status);
// Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst,
const SP_DeviceMemoryBase* device_src, uint64_t size,
TF_Status* status);
// Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
void (*sync_memcpy_htod)(const SP_Device* device,
SP_DeviceMemoryBase* device_dst,
const void* host_src, uint64_t size,
TF_Status* status);
// Blocks the caller while a data segment of the given size is copied from the
// device source to the device destination.
void (*sync_memcpy_dtod)(const SP_Device* device,
SP_DeviceMemoryBase* device_dst,
const SP_DeviceMemoryBase* device_src, uint64_t size,
TF_Status* status);
// Causes the host code to synchronously wait for the event to complete.
void (*block_host_for_event)(const SP_Device* device, SP_Event event,
TF_Status* status);
// Synchronizes all activity occurring in the StreamExecutor's context (most
// likely a whole device).
void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status);
// Enqueues on a stream a user-specified function to be run on the host.
// `callback_arg` should be passed as the first argument to `callback_fn`.
TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream,
SE_StatusCallbackFn callback_fn, void* callback_arg);
} SP_StreamExecutor;
#define SP_STREAMEXECUTOR_STRUCT_SIZE \
TF_OFFSET_OF_END(SP_StreamExecutor, host_callback)
typedef struct SE_CreateStreamExecutorParams {
size_t struct_size;
void* ext; // reserved for future use
SP_StreamExecutor* stream_executor; // output, to be filled by plugin
} SE_CreateStreamExecutorParams;
#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor)
typedef struct SP_Platform {
size_t struct_size;
void* ext; // free-form data set by plugin
// Platform name. Must be null-terminated.
const char* name;
// Device type name, for example GPU. Must be null-terminated.
const char* type;
// Number of visible devices
size_t visible_device_count;
// Callbacks for creating/destroying SP_Device.
void (*create_device)(SE_CreateDeviceParams* params, TF_Status* status);
// Clean up fields inside SP_Device that were allocated
// by the plugin. `device` itself should not be deleted here.
void (*destroy_device)(SP_Device* device);
// Callbacks for creating/destroying SP_StreamExecutor.
void (*create_stream_executor)(SE_CreateStreamExecutorParams* params,
TF_Status* status);
// Clean up fields inside SP_StreamExecutor that were allocated
// by the plugin. `stream_executor` itself should not be deleted here.
void (*destroy_stream_executor)(SP_StreamExecutor* stream_executor);
// Callbacks for creating/destroying SP_TimerFns.
void (*create_timer_fns)(SP_TimerFns* timer, TF_Status* status);
void (*destroy_timer_fns)(SP_TimerFns* timer_fns);
} SP_Platform;
#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, destroy_timer_fns)
typedef struct SE_PlatformRegistrationParams {
size_t struct_size;
void* ext; // reserved for future use
// StreamExecutor C API version.
int32_t major_version;
int32_t minor_version;
int32_t revision_version;
SP_Platform* platform; // output, set by plugin
// Clean up fields inside SP_Platform that were allocated
// by the plugin. `platform` itself should not be deleted here.
void (*destroy_platform)(SP_Platform* platform); // out, set by plugin
} SE_PlatformRegistrationParams;
#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform)
void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_

View File

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Classes and utilities that work with StreamExecutor C API for internal use.
// This includes functions used for device registration and interfaces needed
// for testing.
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
namespace stream_executor {
// Plugin initialization function that a device plugin
// must define.
typedef void (*SEPluginInitFn)(SE_PlatformRegistrationParams* const,
TF_Status* const);
// Loads dso and registers StreamExecutor-based pluggable device.
port::Status RegisterDevicePlugin(const std::string& dso_path);
// Allow registering a plugin using a function (used for testing).
port::Status RegisterDevicePlugin(SEPluginInitFn init_fn);
class CPlatform : public Platform {
public:
explicit CPlatform(SP_Platform platform,
void (*destroy_platform)(SP_Platform*),
SP_StreamExecutor stream_executor, SP_TimerFns timer_fns);
~CPlatform() override;
Id id() const override { return const_cast<int*>(&plugin_id_value_); }
const std::string& Name() const override { return name_; }
int VisibleDeviceCount() const override {
return platform_.visible_device_count;
}
port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
int ordinal) const override;
port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
int ordinal, const PluginConfig& plugin_config) override;
port::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
const StreamExecutorConfig& config) override;
// Trace listener is not supported
void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
}
void UnregisterTraceListener(TraceListener* listener) override {}
void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
private:
SP_Platform platform_;
void (*destroy_platform_)(SP_Platform*);
SP_StreamExecutor stream_executor_;
SP_TimerFns timer_fns_;
const std::string name_;
int plugin_id_value_;
stream_executor::ExecutorCache executor_cache_;
};
} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

View File

@ -0,0 +1,802 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0(the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
struct SP_Stream_st {
explicit SP_Stream_st(int id) : stream_id(id) {}
int stream_id;
};
struct SP_Event_st {
explicit SP_Event_st(int id) : event_id(id) {}
int event_id;
};
struct SP_Timer_st {
explicit SP_Timer_st(int id) : timer_id(id) {}
int timer_id;
};
namespace stream_executor {
namespace {
constexpr int DEVICE_COUNT = 2;
constexpr char DEVICE_NAME[] = "MyDevice";
constexpr char DEVICE_TYPE[] = "GPU";
/*** Create SP_StreamExecutor (with empty functions) ***/
void allocate(const SP_Device* const device, uint64_t size,
int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
}
TF_Bool get_allocator_stats(const SP_Device* const device,
SP_AllocatorStats* const stats) {
return true;
}
TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free,
int64_t* const total) {
return true;
}
void create_stream(const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) {
stream = nullptr;
}
void destroy_stream(const SP_Device* const device, SP_Stream stream) {}
void create_stream_dependency(const SP_Device* const device,
SP_Stream dependent, SP_Stream other,
TF_Status* const status) {}
void get_stream_status(const SP_Device* const device, SP_Stream stream,
TF_Status* const status) {}
void create_event(const SP_Device* const device, SP_Event* event,
TF_Status* const status) {
event = nullptr;
}
void destroy_event(const SP_Device* const device, SP_Event event) {}
SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) {
return SE_EVENT_UNKNOWN;
}
void record_event(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void wait_for_event(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void create_timer(const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) {}
void destroy_timer(const SP_Device* const device, SP_Timer timer) {}
void start_timer(const SP_Device* const device, SP_Stream stream,
SP_Timer timer, TF_Status* const status) {}
void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
TF_Status* const status) {}
void memcpy_dtoh(const SP_Device* const device, SP_Stream stream,
void* host_dst, const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {}
void memcpy_htod(const SP_Device* const device, SP_Stream stream,
SP_DeviceMemoryBase* const device_dst, const void* host_src,
uint64_t size, TF_Status* const status) {}
void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {}
void sync_memcpy_htod(const SP_Device* const device,
SP_DeviceMemoryBase* const device_dst,
const void* host_src, uint64_t size,
TF_Status* const status) {}
void block_host_for_event(const SP_Device* const device, SP_Event event,
TF_Status* const status) {}
void synchronize_all_activity(const SP_Device* const device,
TF_Status* const status) {}
TF_Bool host_callback(SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) {
return true;
}
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
se->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE;
se->allocate = allocate;
se->deallocate = deallocate;
se->get_allocator_stats = get_allocator_stats;
se->device_memory_usage = device_memory_usage;
se->create_stream = create_stream;
se->destroy_stream = destroy_stream;
se->create_stream_dependency = create_stream_dependency;
se->get_stream_status = get_stream_status;
se->create_event = create_event;
se->destroy_event = destroy_event;
se->get_event_status = get_event_status;
se->record_event = record_event;
se->wait_for_event = wait_for_event;
se->create_timer = create_timer;
se->destroy_timer = destroy_timer;
se->start_timer = start_timer;
se->stop_timer = stop_timer;
se->memcpy_dtoh = memcpy_dtoh;
se->memcpy_htod = memcpy_htod;
se->sync_memcpy_dtoh = sync_memcpy_dtoh;
se->sync_memcpy_htod = sync_memcpy_htod;
se->block_host_for_event = block_host_for_event;
se->synchronize_all_activity = synchronize_all_activity;
se->host_callback = host_callback;
}
/*** Create SP_TimerFns ***/
uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
timer_fns->nanoseconds = nanoseconds;
}
/*** Create SP_Platform ***/
void create_timer_fns(SP_TimerFns* timer_fns, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultTimerFns(timer_fns);
}
void destroy_timer_fns(SP_TimerFns* timer_fns) {}
void create_stream_executor(SE_CreateStreamExecutorParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultStreamExecutor(params->stream_executor);
}
void destroy_stream_executor(SP_StreamExecutor* se) {}
void create_device(SE_CreateDeviceParams* params, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
}
void destroy_device(SP_Device* device) {}
void PopulateDefaultPlatform(SP_Platform* platform) {
platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
platform->name = DEVICE_NAME;
platform->type = DEVICE_TYPE;
platform->visible_device_count = DEVICE_COUNT;
platform->create_device = create_device;
platform->destroy_device = destroy_device;
platform->create_stream_executor = create_stream_executor;
platform->destroy_stream_executor = destroy_stream_executor;
platform->create_timer_fns = create_timer_fns;
platform->destroy_timer_fns = destroy_timer_fns;
}
void destroy_platform(SP_Platform* const platform) {}
/*** Registration tests ***/
TEST(StreamExecutor, SuccessfulRegistration) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform);
params->destroy_platform = destroy_platform;
};
port::Status status = RegisterDevicePlugin(plugin_init);
TF_ASSERT_OK(status);
port::StatusOr<Platform*> maybe_platform =
MultiPlatformManager::PlatformWithName("MyDevice");
TF_ASSERT_OK(maybe_platform.status());
Platform* platform = maybe_platform.ConsumeValueOrDie();
ASSERT_EQ(platform->Name(), DEVICE_NAME);
ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT);
port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0);
TF_ASSERT_OK(maybe_executor.status());
StreamExecutor* executor = maybe_executor.ConsumeValueOrDie();
ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice");
}
TEST(StreamExecutor, NameNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform);
params->platform->name = nullptr;
params->destroy_platform = destroy_platform;
};
port::Status status = RegisterDevicePlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
}
TEST(StreamExecutor, CreateDeviceNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform);
params->platform->create_device = nullptr;
params->destroy_platform = destroy_platform;
};
port::Status status = RegisterDevicePlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(),
"'create_device' field in SP_Platform must be set.");
}
/*** StreamExecutor behavior tests ***/
class StreamExecutorTest : public ::testing::Test {
protected:
StreamExecutorTest() {}
void SetUp() override {
PopulateDefaultPlatform(&platform_);
PopulateDefaultStreamExecutor(&se_);
PopulateDefaultTimerFns(&timer_fns_);
}
void TearDown() override {}
StreamExecutor* GetExecutor(int ordinal) {
if (!cplatform_) {
cplatform_ = absl::make_unique<CPlatform>(platform_, destroy_platform,
se_, timer_fns_);
}
port::StatusOr<StreamExecutor*> maybe_executor =
cplatform_->ExecutorForDevice(ordinal);
TF_CHECK_OK(maybe_executor.status());
return maybe_executor.ConsumeValueOrDie();
}
SP_Platform platform_;
SP_StreamExecutor se_;
SP_TimerFns timer_fns_;
std::unique_ptr<CPlatform> cplatform_;
};
TEST_F(StreamExecutorTest, Allocate) {
se_.allocate = [](const SP_Device* const device, uint64_t size,
int64_t memory_space, SP_DeviceMemoryBase* const mem) {
mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE;
mem->opaque = std::malloc(size);
mem->size = size;
};
se_.deallocate = [](const SP_Device* const device,
SP_DeviceMemoryBase* const mem) {
EXPECT_EQ(mem->size, 2 * sizeof(int));
std::free(mem->opaque);
mem->opaque = nullptr;
mem->size = 0;
};
StreamExecutor* executor = GetExecutor(0);
DeviceMemory<int> mem = executor->AllocateArray<int>(2);
ASSERT_NE(mem.opaque(), nullptr);
ASSERT_EQ(mem.size(), 2 * sizeof(int));
executor->Deallocate(&mem);
ASSERT_EQ(mem.opaque(), nullptr);
}
TEST_F(StreamExecutorTest, HostMemoryAllocate) {
static bool allocate_called = false;
static bool deallocate_called = false;
se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) {
allocate_called = true;
return std::malloc(size);
};
se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) {
std::free(mem);
deallocate_called = true;
};
StreamExecutor* executor = GetExecutor(0);
ASSERT_FALSE(allocate_called);
void* mem = executor->HostMemoryAllocate(8);
ASSERT_NE(mem, nullptr);
ASSERT_TRUE(allocate_called);
ASSERT_FALSE(deallocate_called);
executor->HostMemoryDeallocate(mem);
ASSERT_TRUE(deallocate_called);
}
TEST_F(StreamExecutorTest, GetAllocatorStats) {
se_.get_allocator_stats = [](const SP_Device* const device,
SP_AllocatorStats* const stat) -> TF_Bool {
stat->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE;
stat->bytes_in_use = 123;
return true;
};
StreamExecutor* executor = GetExecutor(0);
absl::optional<AllocatorStats> optional_stats = executor->GetAllocatorStats();
ASSERT_TRUE(optional_stats.has_value());
AllocatorStats stats = optional_stats.value();
ASSERT_EQ(stats.bytes_in_use, 123);
}
TEST_F(StreamExecutorTest, DeviceMemoryUsage) {
se_.device_memory_usage = [](const SP_Device* const device,
int64_t* const free,
int64_t* const total) -> TF_Bool {
*free = 45;
*total = 7;
return true;
};
StreamExecutor* executor = GetExecutor(0);
int64 free = 0;
int64 total = 0;
executor->DeviceMemoryUsage(&free, &total);
ASSERT_EQ(free, 45);
ASSERT_EQ(total, 7);
}
TEST_F(StreamExecutorTest, CreateStream) {
static bool stream_created = false;
static bool stream_deleted = false;
se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) -> void {
*stream = new SP_Stream_st(14);
stream_created = true;
};
se_.destroy_stream = [](const SP_Device* const device,
SP_Stream stream) -> void {
auto custom_stream = static_cast<SP_Stream_st*>(stream);
ASSERT_EQ(custom_stream->stream_id, 14);
delete custom_stream;
stream_deleted = true;
};
StreamExecutor* executor = GetExecutor(0);
ASSERT_FALSE(stream_created);
Stream* stream = new Stream(executor);
stream->Init();
ASSERT_TRUE(stream->ok());
ASSERT_TRUE(stream_created);
ASSERT_FALSE(stream_deleted);
delete stream;
ASSERT_TRUE(stream_deleted);
}
TEST_F(StreamExecutorTest, CreateStreamDependency) {
static bool create_stream_dependency_called = false;
se_.create_stream_dependency = [](const SP_Device* const device,
SP_Stream dependent, SP_Stream other,
TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
create_stream_dependency_called = true;
};
StreamExecutor* executor = GetExecutor(0);
Stream dependent(executor);
dependent.Init();
Stream other(executor);
other.Init();
ASSERT_FALSE(create_stream_dependency_called);
dependent.ThenWaitFor(&other);
ASSERT_TRUE(create_stream_dependency_called);
}
TEST_F(StreamExecutorTest, StreamStatus) {
static bool status_ok = true;
se_.get_stream_status = [](const SP_Device* const device, SP_Stream stream,
TF_Status* const status) -> void {
if (status_ok) {
TF_SetStatus(status, TF_OK, "");
} else {
TF_SetStatus(status, TF_INTERNAL, "Test error");
}
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
ASSERT_TRUE(stream.ok());
TF_ASSERT_OK(stream.RefreshStatus());
status_ok = false;
auto updated_status = stream.RefreshStatus();
ASSERT_FALSE(stream.ok());
ASSERT_EQ(updated_status.error_message(), "Test error");
}
TEST_F(StreamExecutorTest, CreateEvent) {
static bool event_created = false;
static bool event_deleted = false;
se_.create_event = [](const SP_Device* const device, SP_Event* event,
TF_Status* const status) -> void {
*event = new SP_Event_st(123);
event_created = true;
};
se_.destroy_event = [](const SP_Device* const device,
SP_Event event) -> void {
auto custom_event = static_cast<SP_Event_st*>(event);
ASSERT_EQ(custom_event->event_id, 123);
delete custom_event;
event_deleted = true;
};
StreamExecutor* executor = GetExecutor(0);
ASSERT_FALSE(event_created);
Event* event = new Event(executor);
event->Init();
ASSERT_TRUE(event_created);
ASSERT_FALSE(event_deleted);
delete event;
ASSERT_TRUE(event_deleted);
}
TEST_F(StreamExecutorTest, PollForEventStatus) {
static SE_EventStatus event_status = SE_EVENT_COMPLETE;
se_.create_event = [](const SP_Device* const device, SP_Event* event,
TF_Status* const status) -> void {
*event = new SP_Event_st(123);
};
se_.destroy_event = [](const SP_Device* const device,
SP_Event event) -> void { delete event; };
se_.get_event_status = [](const SP_Device* const device,
SP_Event event) -> SE_EventStatus {
EXPECT_EQ(event->event_id, 123);
return event_status;
};
StreamExecutor* executor = GetExecutor(0);
Event event(executor);
event.Init();
ASSERT_EQ(event.PollForStatus(), Event::Status::kComplete);
event_status = SE_EVENT_ERROR;
ASSERT_EQ(event.PollForStatus(), Event::Status::kError);
}
TEST_F(StreamExecutorTest, RecordAndWaitForEvent) {
static bool record_called = false;
static bool wait_called = false;
se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) -> void {
*stream = new SP_Stream_st(1);
};
se_.destroy_stream = [](const SP_Device* const device,
SP_Stream stream) -> void { delete stream; };
se_.create_event = [](const SP_Device* const device, SP_Event* event,
TF_Status* const status) -> void {
*event = new SP_Event_st(2);
};
se_.destroy_event = [](const SP_Device* const device,
SP_Event event) -> void { delete event; };
se_.record_event = [](const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {
EXPECT_EQ(stream->stream_id, 1);
EXPECT_EQ(event->event_id, 2);
TF_SetStatus(status, TF_OK, "");
record_called = true;
};
se_.wait_for_event = [](const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {
EXPECT_EQ(stream->stream_id, 1);
EXPECT_EQ(event->event_id, 2);
TF_SetStatus(status, TF_OK, "");
wait_called = true;
};
StreamExecutor* executor = GetExecutor(0);
Event event(executor);
event.Init();
Stream stream(executor);
stream.Init();
ASSERT_FALSE(record_called);
stream.ThenRecordEvent(&event);
ASSERT_TRUE(record_called);
ASSERT_FALSE(wait_called);
stream.ThenWaitFor(&event);
ASSERT_TRUE(wait_called);
}
TEST_F(StreamExecutorTest, CreateTimer) {
static bool timer_created = false;
static bool timer_deleted = false;
se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) -> void {
*timer = new SP_Timer_st(25);
timer_created = true;
};
se_.destroy_timer = [](const SP_Device* const device,
SP_Timer timer) -> void {
auto custom_timer = static_cast<SP_Timer_st*>(timer);
EXPECT_EQ(custom_timer->timer_id, 25);
delete custom_timer;
timer_deleted = true;
};
StreamExecutor* executor = GetExecutor(0);
ASSERT_FALSE(timer_created);
Stream stream(executor);
stream.Init();
Timer* timer = new Timer(executor);
stream.InitTimer(timer);
ASSERT_TRUE(stream.ok());
ASSERT_TRUE(timer_created);
ASSERT_FALSE(timer_deleted);
delete timer;
ASSERT_TRUE(timer_deleted);
}
TEST_F(StreamExecutorTest, StartTimer) {
static bool start_called = false;
static bool stop_called = false;
static TF_Code start_timer_status = TF_OK;
static TF_Code stop_timer_status = TF_OK;
se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) -> void {
*timer = new SP_Timer_st(7);
};
se_.destroy_timer = [](const SP_Device* const device,
SP_Timer timer) -> void { delete timer; };
se_.start_timer = [](const SP_Device* const device, SP_Stream stream,
SP_Timer timer, TF_Status* const status) {
TF_SetStatus(status, start_timer_status, "");
EXPECT_EQ(timer->timer_id, 7);
start_called = true;
};
se_.stop_timer = [](const SP_Device* const device, SP_Stream stream,
SP_Timer timer, TF_Status* const status) {
TF_SetStatus(status, stop_timer_status, "");
EXPECT_EQ(timer->timer_id, 7);
stop_called = true;
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
Timer timer(executor);
stream.InitTimer(&timer);
// Check both start and stop succeed
ASSERT_FALSE(start_called);
stream.ThenStartTimer(&timer);
ASSERT_TRUE(start_called);
ASSERT_FALSE(stop_called);
stream.ThenStopTimer(&timer);
ASSERT_TRUE(stop_called);
// Check start timer fails
ASSERT_TRUE(stream.ok());
start_timer_status = TF_UNKNOWN;
stream.ThenStartTimer(&timer);
ASSERT_FALSE(stream.ok());
// Check stop timer fails
start_timer_status = TF_OK;
stop_timer_status = TF_UNKNOWN;
Stream stream2(executor);
stream2.Init();
Timer timer2(executor);
stream2.InitTimer(&timer2);
stream2.ThenStartTimer(&timer2);
ASSERT_TRUE(stream2.ok());
stream2.ThenStopTimer(&timer2);
ASSERT_FALSE(stream2.ok());
}
TEST_F(StreamExecutorTest, TimerFns) {
se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) -> void {
*timer = new SP_Timer_st(25000);
};
se_.destroy_timer = [](const SP_Device* const device,
SP_Timer timer) -> void { delete timer; };
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
Timer timer(executor);
stream.InitTimer(&timer);
// Our test nanoseconds function just returns value
// passed to SP_Timer_st constructor.
ASSERT_EQ(timer.Nanoseconds(), 25000);
ASSERT_EQ(timer.Microseconds(), 25);
}
TEST_F(StreamExecutorTest, MemcpyToHost) {
se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) -> void {
*stream = new SP_Stream_st(14);
};
se_.destroy_stream = [](const SP_Device* const device,
SP_Stream stream) -> void { delete stream; };
se_.memcpy_dtoh = [](const SP_Device* const device, SP_Stream stream,
void* host_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
EXPECT_EQ(stream->stream_id, 14);
std::memcpy(host_dst, device_src->opaque, size);
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
size_t size = sizeof(int);
int src_data = 34;
int dst_data = 2;
DeviceMemoryBase device_src(&src_data, size);
Stream& stream_ref = stream.ThenMemcpy(&dst_data, device_src, size);
ASSERT_EQ(dst_data, 34);
ASSERT_EQ(stream_ref.implementation(), stream.implementation());
}
TEST_F(StreamExecutorTest, MemcpyFromHost) {
se_.memcpy_htod = [](const SP_Device* const device, SP_Stream stream,
SP_DeviceMemoryBase* const device_dst,
const void* host_src, uint64_t size,
TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
std::memcpy(device_dst->opaque, host_src, size);
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
size_t size = sizeof(int);
int src_data = 18;
int dst_data = 0;
DeviceMemoryBase device_dst(&dst_data, size);
stream.ThenMemcpy(&device_dst, &src_data, size);
ASSERT_EQ(dst_data, 18);
}
TEST_F(StreamExecutorTest, MemcpyDeviceToDevice) {
se_.memcpy_dtod = [](const SP_Device* const device, SP_Stream stream,
SP_DeviceMemoryBase* const device_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
std::memcpy(device_dst->opaque, device_src->opaque, size);
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
size_t size = sizeof(int);
int src_data = 18;
int dst_data = 0;
DeviceMemoryBase device_dst(&dst_data, size);
DeviceMemoryBase device_src(&src_data, size);
stream.ThenMemcpy(&device_dst, device_src, size);
ASSERT_EQ(dst_data, 18);
}
TEST_F(StreamExecutorTest, SyncMemcpyToHost) {
se_.sync_memcpy_dtoh = [](const SP_Device* const device, void* host_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
std::memcpy(host_dst, device_src->opaque, size);
};
StreamExecutor* executor = GetExecutor(0);
size_t size = sizeof(int);
int src_data = 34;
int dst_data = 2;
DeviceMemoryBase device_src(&src_data, size);
TF_ASSERT_OK(executor->SynchronousMemcpyD2H(device_src, size, &dst_data));
ASSERT_EQ(dst_data, 34);
}
TEST_F(StreamExecutorTest, SyncMemcpyFromHost) {
se_.sync_memcpy_htod =
[](const SP_Device* const device, SP_DeviceMemoryBase* const device_dst,
const void* host_src, uint64_t size, TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
std::memcpy(device_dst->opaque, host_src, size);
};
StreamExecutor* executor = GetExecutor(0);
size_t size = sizeof(int);
int src_data = 18;
int dst_data = 0;
DeviceMemoryBase device_dst(&dst_data, size);
TF_ASSERT_OK(executor->SynchronousMemcpyH2D(&src_data, size, &device_dst));
ASSERT_EQ(dst_data, 18);
}
TEST_F(StreamExecutorTest, SyncMemcpyDeviceToDevice) {
se_.sync_memcpy_dtod = [](const SP_Device* const device,
SP_DeviceMemoryBase* const device_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
std::memcpy(device_dst->opaque, device_src->opaque, size);
};
StreamExecutor* executor = GetExecutor(0);
size_t size = sizeof(int);
int src_data = 18;
int dst_data = 0;
DeviceMemoryBase device_dst(&dst_data, size);
DeviceMemoryBase device_src(&src_data, size);
ASSERT_TRUE(executor->SynchronousMemcpy(&device_dst, device_src, size));
ASSERT_EQ(dst_data, 18);
}
TEST_F(StreamExecutorTest, BlockHostForEvent) {
static bool block_host_for_event_called = false;
se_.create_event = [](const SP_Device* const device, SP_Event* event,
TF_Status* const status) {
*event = new SP_Event_st(357);
};
se_.destroy_event = [](const SP_Device* const device, SP_Event event) {
delete event;
};
se_.block_host_for_event = [](const SP_Device* const device, SP_Event event,
TF_Status* const status) -> void {
ASSERT_EQ(event->event_id, 357);
TF_SetStatus(status, TF_OK, "");
block_host_for_event_called = true;
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
ASSERT_FALSE(block_host_for_event_called);
TF_ASSERT_OK(stream.BlockHostUntilDone());
ASSERT_TRUE(block_host_for_event_called);
}
TEST_F(StreamExecutorTest, SynchronizeAllActivity) {
static bool synchronize_all_called = false;
se_.synchronize_all_activity = [](const SP_Device* const device,
TF_Status* const status) {
TF_SetStatus(status, TF_OK, "");
synchronize_all_called = true;
};
StreamExecutor* executor = GetExecutor(0);
ASSERT_FALSE(synchronize_all_called);
ASSERT_TRUE(executor->SynchronizeAllActivity());
ASSERT_TRUE(synchronize_all_called);
}
TEST_F(StreamExecutorTest, HostCallbackOk) {
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) -> TF_Bool {
TF_Status* status = TF_NewStatus();
callback_fn(callback_arg, status);
bool ok = TF_GetCode(status) == TF_OK;
TF_DeleteStatus(status);
return ok;
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
std::function<port::Status()> callback = []() -> port::Status {
return port::Status::OK();
};
stream.ThenDoHostCallbackWithStatus(callback);
ASSERT_TRUE(stream.ok());
}
TEST_F(StreamExecutorTest, HostCallbackError) {
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) -> TF_Bool {
TF_Status* status = TF_NewStatus();
callback_fn(callback_arg, status);
bool ok = TF_GetCode(status) == TF_OK;
TF_DeleteStatus(status);
return ok;
};
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
std::function<port::Status()> callback = []() -> port::Status {
return port::UnimplementedError("Unimplemented");
};
stream.ThenDoHostCallbackWithStatus(callback);
ASSERT_FALSE(stream.ok());
}
} // namespace
} // namespace stream_executor

View File

@ -261,7 +261,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
size_t len, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
@ -279,4 +278,73 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
return nullptr;
}
return tf_tensor;
}
}
TF_Tensor* TF_ForwardInputOrAllocateOutput(
TF_OpKernelContext* context, int* candidate_input_indices,
int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<int> input_indices_array(
candidate_input_indices, num_candidate_input_indices);
tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
tensorflow::Tensor* output_tensor_pointer;
tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
input_indices_array, output_index,
tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
forwarded_input);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor_output;
}
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
int64_t* dims, int num_dims,
TF_AllocatorAttributes* attributes,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
TF_SetStatus(status, TF_OK, "");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
if (attributes && !attributes->struct_size) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"TF_AllocatorAttributes struct "
"size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
return nullptr;
}
tensorflow::AllocatorAttributes allocator_attr;
if (attributes && attributes->on_host) {
allocator_attr.set_on_host(true);
}
tensorflow::Status s;
tensorflow::Tensor tensor;
s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimarray), &tensor,
allocator_attr);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
TF_Tensor* tf_tensor;
tf_tensor = TF_TensorFromTensor(tensor, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor;
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
@ -199,6 +200,26 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
int64_t* dims, int num_dims,
size_t len, TF_Status* status);
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer. The index of the
// forwarded input will be assign to output argument forwarded_input (if it's
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
// -1.
TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
TF_OpKernelContext* context, int* candidate_input_indices,
int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status);
// Allocates a temporary Tensor of the specified type and shape. The
// Tensor must not be used after kernel construction is
// complete.
//
// num_dims must equal the size of array dims
TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp(
TF_OpKernelContext* context, TF_DataType dtype, int64_t* dims, int num_dims,
TF_AllocatorAttributes* alloc_attrs, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -39,6 +39,33 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "histogram_summary_op",
prefix = "histogram_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "merge_summary_op",
prefix = "merge_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_gen_op_libs(
op_lib_names = ["bitcast"],
deps = [
@ -59,6 +86,24 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = ["histogram_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_gen_op_libs(
op_lib_names = ["merge_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
@ -122,6 +167,8 @@ filegroup(
name = "android_all_op_kernels",
srcs = [
"bitcast_op.cc",
"histogram_summary_op.cc",
"merge_summary_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
"tensor_shape_utils.h",
@ -133,6 +180,8 @@ filegroup(
name = "android_all_ops",
srcs = [
"ops/bitcast.cc",
"ops/histogram_summary.cc",
"ops/merge_summary.cc",
"ops/summary.cc",
],
)

View File

@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status.
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope.
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// Used to pass the operation node name from kernel construction to
// kernel computation.
struct HistogramSummaryOp {
std::string op_node_name;
};
void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) {
HistogramSummaryOp* kernel = new HistogramSummaryOp;
TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx);
kernel->op_node_name =
std::string(string_view_name.data, string_view_name.len);
return kernel;
}
void HistogramSummaryOp_Delete(void* kernel) {
delete static_cast<HistogramSummaryOp*>(kernel);
}
template <typename T>
void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel);
TF_Tensor* tags;
TF_Tensor* values;
Safe_TF_StatusPtr status(TF_NewStatus());
TF_GetInput(ctx, 0, &tags, status.get());
Safe_TF_TensorPtr safe_tags_ptr(tags);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
TF_GetInput(ctx, 1, &values, status.get());
Safe_TF_TensorPtr safe_values_ptr(values);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
if (TF_NumDims(safe_tags_ptr.get()) != 0) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
// Cast values to array to access tensor elements by index
auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get()));
tensorflow::histogram::Histogram histo;
for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) {
const double double_val = static_cast<double>(values_array[i]);
if (Eigen::numext::isnan(double_val)) {
std::ostringstream err;
err << "Nan in summary histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
} else if (Eigen::numext::isinf(double_val)) {
std::ostringstream err;
err << "Infinity in Histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
histo.Add(double_val);
}
tensorflow::Summary s;
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& tag =
*(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get())));
v->set_tag(tag.data(), tag.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
template <typename T>
void RegisterHistogramSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create,
&HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("HistogramSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Histogram Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) {
RegisterHistogramSummaryOpKernel<tensorflow::int64>();
RegisterHistogramSummaryOpKernel<tensorflow::uint64>();
RegisterHistogramSummaryOpKernel<tensorflow::int32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint16>();
RegisterHistogramSummaryOpKernel<tensorflow::int16>();
RegisterHistogramSummaryOpKernel<tensorflow::int8>();
RegisterHistogramSummaryOpKernel<tensorflow::uint8>();
RegisterHistogramSummaryOpKernel<Eigen::half>();
RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>();
RegisterHistogramSummaryOpKernel<float>();
RegisterHistogramSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,123 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <sstream>
#include <unordered_set>
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// dummy functions used for kernel registration
void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
void MergeSummaryOp_Delete(void* kernel) {}
void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
tensorflow::Summary s;
std::unordered_set<tensorflow::string> tags;
Safe_TF_StatusPtr status(TF_NewStatus());
for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) {
TF_Tensor* input;
TF_GetInput(ctx, input_num, &input, status.get());
Safe_TF_TensorPtr safe_input_ptr(input);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
auto tags_array =
static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get()));
for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) {
const tensorflow::tstring& s_in = tags_array[i];
tensorflow::Summary summary_in;
if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT,
"Could not parse one of the summary inputs");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
for (int v = 0; v < summary_in.value_size(); ++v) {
// This tag is unused by the TensorSummary op, so no need to check for
// duplicates.
const tensorflow::string& tag = summary_in.value(v).tag();
if ((!tag.empty()) && !tags.insert(tag).second) {
std::ostringstream err;
err << "Duplicate tag " << tag << " found in summary inputs ";
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
*s.add_value() = summary_in.value(v);
}
}
}
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
void RegisterMergeSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create,
&MergeSummaryOp_Compute, &MergeSummaryOp_Delete);
TF_RegisterKernelBuilder("MergeSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Merge Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) {
RegisterMergeSummaryOpKernel();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void histogram_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_HistogramSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("HistogramSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tag: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype = DT_FLOAT");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &histogram_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "HistogramSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool HistogramSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("HistogramSummary")) {
Register_HistogramSummaryOp();
}
return true;
}();

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void merge_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_MergeSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("MergeSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "inputs: N * string");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "N: int >= 1");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &merge_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "MergeSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool MergeSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("MergeSummary")) {
Register_MergeSummaryOp();
}
return true;
}();

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"

View File

@ -368,6 +368,16 @@ class DeviceKernelOpTest : public OpsTestBase {
#endif
};
// Validates that the tensor has shape and type corresponding to
// dims and dtype.
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype);
// Copies data of length tensor_size_bytes from values to tensor.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx);
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
@ -379,22 +389,11 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(1, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set output to 3
float* data = reinterpret_cast<float*>(TF_TensorData(output));
float value = 3.0f;
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
tensor_size_bytes);
#else
*data = value;
#endif
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -417,12 +416,8 @@ TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/0, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(0, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -442,27 +437,16 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
TF_Status* s = TF_NewStatus();
// Allocate 2x3 output
int64_t dim[2] = {2, 3};
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(2, TF_NumDims(output));
EXPECT_EQ(2, TF_Dim(output, 0));
EXPECT_EQ(3, TF_Dim(output, 1));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set output to [1 2 3 4 5 6]
void* data = TF_TensorData(output);
float value[6] = {1, 2, 3, 4, 5, 6};
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
tensor_size_bytes);
#else
memcpy(data, value, tensor_size_bytes);
#endif
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -474,4 +458,200 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
// Allocate scalar TF_Tensor
TF_Status* s = TF_NewStatus();
int64_t dim = 1;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set TF_Tensor value to 3
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp0").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
// Allocate empty TF_Tensor
int64_t dim = 0;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [0] values: >",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
// Allocate 2x3 TF_Tensor
int64_t dim[2] = {2, 3};
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set TF_Tensor values to [1 2 3 4 5 6]
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
output->DebugString(100));
}
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
const char* node_name = "TestForwardInputOrAllocateOutputKernel";
const char* op_name = "BazOp";
const char* device_name = "FakeDeviceName";
REGISTER_OP(op_name)
.Input("input1: float")
.Input("input2: float")
.Output("output1: float")
.Attr("SomeDataTypeAttr: type");
// A kernel whose Compute function that forwards a scalar input to output
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
int candidate_input_indices[1] = {0};
int forwarded_input;
int64_t output_dims[1] = {};
TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
/*context=*/ctx, candidate_input_indices,
/*num_candidate_input_indices=*/1,
/*output_index=*/0, output_dims, /*output_num_dims=*/0,
&forwarded_input, /*status=*/s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(forwarded_input, 0);
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(0, TF_NumDims(output));
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
my_compute_func, nullptr);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
AllocatorAttributes alloc_attrs;
p.output_attr_array = &alloc_attrs;
Tensor t(123.0f);
gtl::InlinedVector<TensorValue, 4> inputs;
// GetFakeKernel requires a NodeDef with two inputs
inputs.emplace_back(&t);
inputs.emplace_back();
p.inputs = &inputs;
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
p.op_kernel = kernel.get();
OpKernelContext ctx(&p);
kernel->Compute(&ctx);
ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
}
}
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype) {
EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
EXPECT_EQ(num_dims, TF_NumDims(tensor));
for (int i = 0; i < num_dims; ++i) {
EXPECT_EQ(dims[i], TF_Dim(tensor, i));
}
}
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx) {
T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
tensor_size_bytes);
#else
memcpy(data, values, tensor_size_bytes);
#endif
}
} // namespace tensorflow

View File

@ -28,6 +28,7 @@ void TF_Log(TF_LogLevel level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
va_end(args);
switch (level) {
case TF_INFO:
LOG(INFO) << message;
@ -48,6 +49,7 @@ void TF_VLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
va_end(args);
VLOG(level) << message;
}
@ -55,5 +57,6 @@ void TF_DVLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
va_end(args);
DVLOG(level) << message;
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -45,6 +46,16 @@ limitations under the License.
extern "C" {
#endif
// Allocator Attributes used for tensor allocation.
typedef struct TF_AllocatorAttributes {
size_t struct_size;
// Set boolean to 1 for CPU allocation, else 0.
TF_Bool on_host;
} TF_AllocatorAttributes;
#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \
TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host)
// --------------------------------------------------------------------------
// TF_Tensor holds a multi-dimensional array of elements of a single data type.
// For all types other than TF_STRING, the data buffer stores elements

View File

@ -47,6 +47,7 @@ cc_library(
# TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate
# tf_lib depending on the build platform.
"@com_google_absl//absl/memory:memory",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
]),
@ -171,6 +172,7 @@ tf_cc_test(
deps = [
":constants",
":loader",
":reader",
":signature_constants",
":tag_constants",
"//tensorflow/core:lib",

View File

@ -51,8 +51,32 @@ cc_library(
deps = [
":concrete_function",
":concrete_function_list",
":signature_def_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:status",
],
)
cc_library(
name = "signature_def_function",
hdrs = [
"signature_def_function.h",
],
deps = [
":signature_def_function_metadata",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
"//tensorflow/cc/experimental/base/public:status",
],
)
cc_library(
name = "signature_def_function_metadata",
hdrs = [
"signature_def_function_metadata.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
],
)

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
#include "tensorflow/cc/saved_model/experimental/public/signature_def_function.h"
namespace tensorflow {
namespace experimental {
@ -80,8 +81,8 @@ class SavedModelAPI {
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
// is bound to SavedModelAPI it was loaded from.
ConcreteFunction* GetSignatureDefFunction(const std::string& function_path,
Status* status);
SignatureDefFunction* GetSignatureDefFunction(
const std::string& function_path, Status* status);
// Lists all Conrete Functions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
@ -140,14 +141,14 @@ inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
return ConcreteFunction::wrap(function);
}
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
const std::string& function_path, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
TF_SignatureDefFunction* function = TF_GetSavedModelSignatureDefFunction(
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
return SignatureDefFunction::wrap(function);
}
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// SignatureDefFunctions are functions that correspond to either:
// "signatures" saved from a TF2 SavedModel APIs:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/save.py#L830-L854
// Or the "SignatureDefMap" saved from TF1 SavedModel APIs:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/load_v1_in_v2_test.py#L170-L174
// In both cases, a SignatureDef is serialized as a SignatureDef protobuf:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/core/protobuf/meta_graph.proto#L260-L330
// and represents a computation defined by a TF subgraph.
// These Signatures were primarily designed to be interoperable with the legacy
// TF 1 Session-based C++ SavedModelBundle loading APIs:
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/cc/saved_model/loader.h#L96-L108
// SignatureDefFunctions have different semantics from regular TF2
// ConcreteFunctions, and are mainly intended provide a serving-friendly
// transition point from the TF1 Session API.
// First, SignatureDefFunctions have different calling conventions.
// SignatureDefFunctions' inputs and outputs are constrained to **flattened
// lists of TensorHandles only**. They do not support more exotic input/output
// types (like optionals, generators, etc). Additionally, this flattening means
// they will not preserve the exact interface of the original tf.function they
// were traced from, as things like composite tensors decay into their
// internal dense tensor representation.
// Second, all inputs and outputs are "named", and these names are load bearing
// (eg: they are part of the interface of tensorflow_serving):
// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L21
// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L39
// The name of each input/output is stored in the corresponding tf::Argument in
// SignatureDefFunctionMetadata::arguments(). Users must ensure the order of
// TensorHandles passed to the function matches with the order of named
// arguments. Similarly the name of the outputs is stored in
// SignatureDefFunctionMetadata::returns().
class SignatureDefFunction final {
public:
// Returns FunctionMetadata associated with this ConcreteFunction.
const SignatureDefFunctionMetadata* GetFunctionMetadata();
private:
friend class SavedModelAPI;
friend class ConcreteFunctionList;
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
// when moving out of experimental.
static SignatureDefFunction* wrap(TF_SignatureDefFunction* p) {
return reinterpret_cast<SignatureDefFunction*>(p);
}
static TF_SignatureDefFunction* unwrap(SignatureDefFunction* p) {
return reinterpret_cast<TF_SignatureDefFunction*>(p);
}
};
inline const SignatureDefFunctionMetadata*
SignatureDefFunction::GetFunctionMetadata() {
return SignatureDefFunctionMetadata::wrap(
TF_SignatureDefFunctionGetMetadata(unwrap(this)));
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// SignatureDefFunctionMetadata stores additional information on each input
// and output's names, dtypes, and shape.
class SignatureDefFunctionMetadata final {
// TODO(bmzhao): Add getters here as necessary.
private:
friend class SignatureDefFunction;
static SignatureDefFunctionMetadata* wrap(
TF_SignatureDefFunctionMetadata* p) {
return reinterpret_cast<SignatureDefFunctionMetadata*>(p);
}
static TF_SignatureDefFunctionMetadata* unwrap(
SignatureDefFunctionMetadata* p) {
return reinterpret_cast<TF_SignatureDefFunctionMetadata*>(p);
}
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
return Status::OK();
}
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
Session* session_p = nullptr;
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
session->reset(session_p);
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
return (*session)->Create(meta_graph_def.graph_def());
}
Tensor CreateStringTensor(const string& value) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<tstring>()() = value;
@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
<< export_dir;
} // namespace
const string debug_info_pb_path =
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
*debug_info_proto =
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
}
return Status::OK();
SavedModelBundleInterface::~SavedModelBundleInterface() {}
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
const MetaGraphDef& meta_graph,
std::unique_ptr<Session>* session) {
Session* session_p = nullptr;
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
session->reset(session_p);
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
return (*session)->Create(meta_graph.graph_def());
}
Status LoadSavedModelInternal(const SessionOptions& session_options,
@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle) {
const uint64 read_start_microseconds = Env::Default()->NowMicros();
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
&bundle->meta_graph_def));
TF_RETURN_IF_ERROR(
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
// Record walltime spent in restoring graph from disk, but postpone metric
// increments until graph init finishes.
const uint64 restore_graph_walltime =
GetLatencyMicroseconds(read_start_microseconds);
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));
load_latency_by_stage->GetCell(export_dir, "restore_graph")
->Add(restore_graph_walltime);
// Record wall time spent in init op.
load_latency_by_stage->GetCell(export_dir, "init_graph")
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
session_options, bundle->meta_graph_def, &bundle->session));
TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
export_dir, &bundle->session));
return Status::OK();
}
} // namespace
SavedModelBundleInterface::~SavedModelBundleInterface() {}
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
@ -424,6 +382,35 @@ class LiteSessionWrapper : public Session {
};
} // namespace
Status RestoreSession(const RunOptions& run_options,
const MetaGraphDef& meta_graph, const string& export_dir,
std::unique_ptr<Session>* session) {
const uint64 read_start_microseconds = Env::Default()->NowMicros();
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
meta_graph.saver_def().restore_op_name(),
meta_graph.saver_def().filename_tensor_name(),
asset_file_defs, session->get()));
// Record walltime spent in restoring graph from disk, but postpone metric
// increments until graph init finishes.
const uint64 restore_graph_walltime =
GetLatencyMicroseconds(read_start_microseconds);
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
internal::GetInitOp(export_dir, meta_graph, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
asset_file_defs, session->get(), init_op_name));
load_latency_by_stage->GetCell(export_dir, "restore_graph")
->Add(restore_graph_walltime);
// Record wall time spent in init op.
load_latency_by_stage->GetCell(export_dir, "init_graph")
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
return Status::OK();
}
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,

View File

@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface {
protobuf::Map<string, SignatureDef> signatures_;
};
// Restore variable and resources in the SavedModel export dir for the
// indicated metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status RestoreSession(const RunOptions& run_options,
const MetaGraphDef& meta_graph, const string& export_dir,
std::unique_ptr<Session>* session);
// Initialize a session which wraps this metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
const MetaGraphDef& meta_graph,
std::unique_ptr<Session>* session);
/// Loads a SavedModel from the specified export directory. The MetaGraphDef
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_set>
#include "absl/memory/memory.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
<< export_dir;
const string debug_info_pb_path =
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
*debug_info_proto =
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def);
// Store debug info from the SavedModel export dir.
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto);
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_

View File

@ -106,5 +106,11 @@ TEST_F(ReaderTest, InvalidExportPath) {
EXPECT_FALSE(st.ok());
}
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
std::unique_ptr<GraphDebugInfo> debug_info_proto;
TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info_proto));
}
} // namespace
} // namespace tensorflow

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace {
@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) {
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
MetaGraphDef actual_metagraph;
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&actual_metagraph));
EXPECT_EQ(actual_metagraph.DebugString(),
bundle.meta_graph_def.DebugString());
}
TEST_F(LoaderTest, RestoreSession) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
SavedModelBundle actual_bundle;
const std::unordered_set<std::string> tags = {kSavedModelTagServe};
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags,
&actual_bundle.meta_graph_def));
TF_ASSERT_OK(LoadMetagraphIntoSession(
session_options, actual_bundle.meta_graph_def, &actual_bundle.session));
TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def,
export_dir, &actual_bundle.session));
CheckSavedModelBundle(export_dir, actual_bundle);
}
TEST_F(LoaderTest, NoTagMatch) {
SavedModelBundle bundle;
RunOptions run_options;

Some files were not shown because too many files have changed in this diff Show More