Merge branch 'master' into list_keys

commit ff1ef12025

.bazelrc (19 lines changed)
@@ -461,12 +461,12 @@ build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"

# Map default to CUDA 10.1.
# Map default to CUDA 11 for PY35 and greater.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8

# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
@@ -583,9 +583,9 @@ build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
@@ -595,8 +595,7 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"

build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain

build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc

RELEASE.md (20 lines changed)
@@ -81,6 +81,12 @@
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state.
* Added tf.data service support for sharing dataset graphs via shared
filesystem instead of over RPC. This reduces load on the dispatcher,
improving performance of distributing datasets. For this to work, the
dispatcher's `work_dir` must be accessible from workers. If the worker
fails to read from the `work_dir`, it falls back to using RPC for dataset
graph transfer.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
@@ -88,6 +94,7 @@
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
@@ -106,7 +113,8 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
@@ -155,6 +163,14 @@
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
@@ -251,6 +267,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Remove environment variable `TF_USE_CUDNN`.
* Others
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
@@ -1582,6 +1599,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric,
color palette of the frame. This has been fixed now.
* image.resize now considers proper pixel centers and has new kernels
(incl. anti-aliasing).
* Added an isotonic regression solver (tf.nn.isotonic_regression).
* Performance
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector
@@ -58,9 +58,9 @@ filegroup(
    visibility = ["//visibility:public"],
)

filegroup(
cc_library(
    name = "pywrap_required_hdrs",
    srcs = [
    textual_hdrs = [
        "c_api_internal.h",
        "c_api_macros.h",
        "conversion_macros.h",
@@ -220,6 +220,7 @@ cc_library(
    name = "logging",
    srcs = ["logging.cc"],
    hdrs = ["logging.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":c_api_macros",
        "//tensorflow/core/platform:logging",
@@ -240,6 +240,7 @@ tf_cuda_cc_test(
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:array_grad",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/cc/profiler",
@@ -255,6 +256,72 @@ tf_cuda_cc_test(
    ],
)

cc_library(
    name = "mnist_gradients_testutil",
    srcs = [
        "mnist_gradients_testutil.cc",
    ],
    hdrs = [
        "mnist_gradients_testutil.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":tape",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

tf_cuda_cc_test(
    name = "mnist_gradients_test",
    size = "small",
    srcs = [
        "mnist_gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":mnist_gradients_testutil",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/gradients:nn_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "abstract_tensor_handle",
    hdrs = ["abstract_tensor_handle.h"],
@@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
@@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
  }

  if (remote_func_outputs) {
    const string backing_device =
        TFE_TensorHandleBackingDeviceName(retvals[0], status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    EXPECT_EQ(backing_device, task2_name);
  }

  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
@@ -102,6 +102,32 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
  return th;
}

TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
                                                int64_t dims[], int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
                                              int64_t dims[], int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
  constexpr int64_t dims[] = {100, 100};
  constexpr int num_elements = dims[0] * dims[1];
@@ -40,6 +40,14 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
                                                  float data[], int64_t dims[],
                                                  int num_dims);

// Get a Matrix TensorHandle with given float values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
                                                int64_t dims[], int num_dims);

// Get a Matrix TensorHandle with given int values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
                                              int64_t dims[], int num_dims);

// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
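For orientation, a minimal usage sketch of the new float helper; the wrapper function and the 2x3 shape here are illustrative, not part of the commit:

// Hypothetical example: build a 2x3 float handle with the new test util,
// use it, then release it. Assumes an initialized TFE_Context* from the
// usual test setup.
void ExampleBuildTensor(TFE_Context* ctx) {
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t dims[] = {2, 3};
  TFE_TensorHandle* th =
      TestTensorHandleWithDimsFloat(ctx, data, dims, /*num_dims=*/2);
  // ... use `th` in a test ...
  TFE_DeleteTensorHandle(th);
}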
@@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation {
      return errors::FailedPrecondition(
          "GraphOperation::Reset must be called before calling SetOpName.");
    }
    op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
    // TODO(b/145674566): We use Graph::NewName to get a unique name here but
    // this may not be consistent with python's naming policy.
    mutex_lock l(g_->mu);
    op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
                                          g_->graph.NewName(op_name).c_str()));
    return Status::OK();
  }
  const string& Name() const override { return op_type_; }
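For context, the switch to `Graph::NewName` is what makes duplicate user-supplied op names safe, which the `my_add` test changes below rely on. A rough sketch of the behavior this assumes, using the usual `tensorflow::Graph` API (return values are illustrative):

#include <string>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"

void ExampleNewName() {
  tensorflow::Graph g(tensorflow::OpRegistry::Global());
  // Both calls return names derived from "my_add" that are unique within
  // this graph; the exact suffix format is an implementation detail.
  std::string n1 = g.NewName("my_add");
  std::string n2 = g.NewName("my_add");  // differs from n1
}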
@@ -557,7 +557,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add1", s);
    TF_AbstractOpSetOpName(add_op, "my_add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg0, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();
@@ -579,7 +579,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add2", s);
    TF_AbstractOpSetOpName(add_op, "my_add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg1, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {

Status GradientRegistry::Register(const string& op_name,
                                  GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
                 AbstractTensorHandle** result) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(t)).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}
}  // namespace

class IncomingGradientsImpl : public IncomingGradients {
 public:
  explicit IncomingGradientsImpl(
      absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
      DefaultGradientFunction* default_gradients)
      : grad_inputs_(grad_inputs),
        ctx_(ctx),
        default_gradients_(default_gradients) {}
  AbstractTensorHandle* operator[](int i) const override {
    return default_gradients_->get(ctx_, grad_inputs_, i);
  }
  size_t size() const override { return grad_inputs_.size(); }

 private:
  absl::Span<AbstractTensorHandle* const> grad_inputs_;
  Context* ctx_;
  DefaultGradientFunction* default_gradients_;
};

AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
    : outputs_(op.outputs) {
  for (auto output : outputs_) {
    output->Ref();
  }
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
    Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
  if (grad_inputs[i]) {
    return grad_inputs[i];
  }
  if (cached_default_grads_[i]) {
    return cached_default_grads_[i].get();
  }
  AbstractTensorHandle* result = nullptr;
  Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
  if (!s.ok()) {
    if (result) {
      result->Unref();
    }
    VLOG(1) << "Failed to create ZerosLike for index " << i;
    return nullptr;
  }
  cached_default_grads_[i].reset(result);
  return result;
}

PassThroughDefaultGradients::PassThroughDefaultGradients(
    const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
    Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
  return grad_inputs[i];
}

Status GradientRegistry::Register(
    const string& op_name, BackwardFunctionFactory backward_function_factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, factory});
  registry_.insert({op_name, backward_function_factory});
  return Status::OK();
}
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* grad_fn) const {
    std::unique_ptr<BackwardFunction>* backward_function) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  grad_fn->reset(iter->second(op));
  backward_function->reset(iter->second(op));
  return Status::OK();
}
@@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
  }
  return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
  Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  if (isa<tracing::TracingOperation>(op.get())) {
    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(handle_)).c_str());
    if (!s.ok()) {
      return nullptr;
    }
  }
  s = op->AddInput(handle_);
  if (!s.ok()) {
    return nullptr;
  }
  int num_outputs = 1;
  // TODO(srbs): Figure out who is in charge of releasing this.
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }

// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(

// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
    GradientFunction* backward_function,
    BackwardFunction* backward_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    std::vector<AbstractTensorHandle*>* result) const {
  if (backward_function == nullptr) return Status::OK();
  Context ctx = {ctx_};
  return backward_function->Compute(&ctx, output_gradients, result);
  IncomingGradientsImpl incoming_gradients(
      output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
  return backward_function->GetGradientFunction()->Compute(
      &ctx, incoming_gradients, result);
}

// Looks up the ID of a Gradient.
@@ -373,15 +424,15 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
  }
  tape->RecordOperation(
      op_->Name(), tape_tensors, input_ids, input_dtypes,
      [registry, forward_op_]() -> GradientFunction* {
        std::unique_ptr<GradientFunction> grad_fn;
        Status s = registry.Lookup(*forward_op_, &grad_fn);
      [registry, forward_op_]() -> BackwardFunction* {
        std::unique_ptr<BackwardFunction> backward_fn;
        Status s = registry.Lookup(*forward_op_, &backward_fn);
        if (!s.ok()) {
          return nullptr;
        }
        return grad_fn.release();
        return backward_fn.release();
      },
      [](GradientFunction* ptr) {
      [](BackwardFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
@@ -55,18 +55,25 @@ struct Context {
 public:
  AbstractContext* ctx;
};

class IncomingGradients {
 public:
  virtual AbstractTensorHandle* operator[](int i) const = 0;
  virtual size_t size() const = 0;
  virtual ~IncomingGradients() {}
};

class GradientFunction {
 public:
  // TODO(srbs): How do we support CompositeTensors e.g. IndexedSlices in
  // `grad_inputs`.
  virtual Status Compute(Context* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_inputs,
  virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                         std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
 public:
  string op_name;
@@ -76,18 +83,86 @@ struct ForwardOperation {
  AbstractContext* ctx;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
 public:
  Status Register(const string& op, GradientFunctionFactory factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* grad_fn) const;
  virtual AbstractTensorHandle* get(
      Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
      int i) = 0;
  virtual ~DefaultGradientFunction() {}
};

// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
 public:
  explicit AllZerosDefaultGradients(const ForwardOperation& op);
  AbstractTensorHandle* get(Context* ctx,
                            absl::Span<AbstractTensorHandle* const> grad_inputs,
                            int i) override;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
  // TODO(srbs): We do not always need to keep the tensors around. In immediate
  // execution mode we just need to store the shape and dtype. During tracing
  // we may need to keep the tensor around if the shape is not fully defined.
  std::vector<AbstractTensorHandle*> outputs_;
  std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};

// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
 public:
  explicit PassThroughDefaultGradients(const ForwardOperation& op);
  AbstractTensorHandle* get(Context* ctx,
                            absl::Span<AbstractTensorHandle* const> grad_inputs,
                            int i) override;
};

// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
 public:
  BackwardFunction(GradientFunction* gradient_function,
                   DefaultGradientFunction* default_gradients)
      : gradient_function_(gradient_function),
        default_gradients_(default_gradients) {}
  GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
  DefaultGradientFunction* GetDefaultGradientFunction() {
    return default_gradients_.get();
  }

 private:
  std::unique_ptr<GradientFunction> gradient_function_;
  std::unique_ptr<DefaultGradientFunction> default_gradients_;
};

using BackwardFunctionFactory =
    std::function<BackwardFunction*(const ForwardOperation& op)>;

// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  BackwardFunctionFactory backward_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<BackwardFunction>* backward_function) const;

 private:
  absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};

// Returns a unique id for the tensor which is used by the tape to build
@@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@@ -123,7 +205,7 @@ class TapeTensor {

 private:
  AbstractTensorHandle* handle_;
  // The context where OnesLike and ZerosLike ops are to be created.
  // The context where OnesLike ops are to be created.
  AbstractContext* ctx_;
};
@@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
    : public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}
@@ -147,7 +229,7 @@ class TapeVSpace

  // Calls the passed-in backward function.
  Status CallBackwardFunction(
      GradientFunction* backward_function,
      BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      std::vector<AbstractTensorHandle*>* result) const override;
@@ -168,8 +250,14 @@ class TapeVSpace
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
                                             GradientFunction, TapeTensor>;
                                             BackwardFunction, TapeTensor>;

}  // namespace gradients
}  // namespace tensorflow
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@@ -50,6 +51,7 @@ class CppGradients
Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
  return Status::OK();
}
@@ -94,6 +96,26 @@ Status Exp(AbstractContext* ctx, Tape* tape,
                 registry);
}

// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs,
                 const GradientRegistry& registry) {
  AbstractOperationPtr identity_n_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
                           /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(identity_n_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
                           ->SetOpName("my_identity_n"));
  }
  TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
  int num_retvals = outputs.size();
  return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
                 tape, registry);
}

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@@ -116,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx,
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads));
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto add_output : add_outputs) {
    add_output->Unref();
  }
@@ -146,7 +169,8 @@ Status ExpGradModel(AbstractContext* ctx,
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads));
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto exp_output : exp_outputs) {
    exp_output->Unref();
  }
@@ -155,6 +179,41 @@
  return Status::OK();
}

// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
                          absl::Span<AbstractTensorHandle* const> inputs,
                          absl::Span<AbstractTensorHandle*> outputs,
                          const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));
  tape->Watch(ToId(inputs[1]));

  vector<AbstractTensorHandle*> identity_n_outputs(2);
  TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
                               absl::MakeSpan(identity_n_outputs), registry));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto identity_n_output : identity_n_outputs) {
    identity_n_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@@ -389,13 +448,72 @@ TEST_P(CppGradients, TestExpGrad) {
  result_tensor = nullptr;
}

// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
TEST_P(CppGradients, TestIdentityNGrad) {
  // Pseudo-code:
  //
  // tape.watch(x1)
  // tape.watch(x2)
  // unused, y = IdentityN([x1, x2])
  // outputs = tape.gradient(y, [x1, x2])
  // Expected: [nullptr, 1]
  //
  // This test is interesting because the current implementation of GradientTape
  // would return [0, 1] whereas we use build_default_zeros_grads=false here
  // so we get back [nullptr, 1].
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x1;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x1.reset(x_raw);
  }
  AbstractTensorHandlePtr x2;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x2.reset(x_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  EXPECT_EQ(outputs[0], nullptr);
  TF_Tensor* result_tensor;
  s = getValue(outputs[1], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[1]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;
}

// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(true, false),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
@@ -57,15 +57,10 @@ class ImmediateExecutionContext : public AbstractContext {

  // Create a tensor instance from the given data buffer and description.
  // `memory_releaser` will be called on destruction, and it's responsible for
  // cleaning up the underlying buffer. `convert_string` indicates whether it
  // has to handle tstring conversion. Expected to be removed once tstring
  // migration is done.
  virtual AbstractTensorInterface* CreateTensor(DataType dtype,
                                                const int64_t* dims,
                                                int num_dims, void* data,
                                                size_t len, bool convert_string,
                                                MemoryReleaser memory_releaser,
                                                void* memory_releaser_arg) = 0;
  // cleaning up the underlying buffer.
  virtual AbstractTensorInterface* CreateTensor(
      DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
      MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;

  // Create a handle to wrap and manage a Tensor
  virtual ImmediateExecutionTensorHandle* CreateLocalHandle(

tensorflow/c/eager/mnist_gradients_test.cc (new file, 781 lines)
@@ -0,0 +1,781 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_SetTracingImplementation(std::get<0>(GetParam()));
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyLossRegisterer));
  return Status::OK();
}

// ========================= Test Util Functions ==============================

// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                     int64_t dims[], int num_dims,
                                     AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
                                   int64_t dims[], int num_dims,
                                   AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return Status::OK();
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

// =========================== Start Tests ================================

TEST_P(CppGradients, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
  }

  TF_Tensor* dB_tensor;
  s = GetValue(outputs[1], &dB_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dB_tensor),
         TF_TensorByteSize(dB_tensor));

  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(dA_tensor);
  TF_DeleteTensor(dB_tensor);
}
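As a sanity check on `expected_dA` and `expected_dB` above: assuming `MatMulGradModel` seeds the tape with an all-ones upstream gradient dY (which is consistent with these numbers), the matmul gradients for Y = AB are dA = dY B^T and dB = A^T dY:

dA = dY\,B^{\top}
   = \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
     \begin{pmatrix} 0.5 & 1 \\ -1 & 1 \end{pmatrix}
   = \begin{pmatrix} -0.5 & 2 \\ -0.5 & 2 \end{pmatrix},
\qquad
dB = A^{\top}\,dY
   = \begin{pmatrix} 1 & 3 \\ 2 & 4 \end{pmatrix}
     \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
   = \begin{pmatrix} 4 & 4 \\ 6 & 6 \end{pmatrix}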
TEST_P(CppGradients, TestMNISTForward) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t dims_y[] = {2};
  num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, dims_y, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

  float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[2] = {9.6f, 27.2f};
  for (int j = 0; j < 2; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}
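The expected values above are consistent with a two-layer forward pass, scores = Relu(X W1) W2, followed by sparse softmax cross-entropy against y. The model itself lives in `mnist_gradients_testutil`; this reading is inferred from the numbers, not shown in the diff. For the first test's 2x2 inputs:

XW_1 = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
       \begin{pmatrix} -1 & 10 \\ 0.5 & 1 \end{pmatrix}
     = \begin{pmatrix} 0 & 12 \\ -1 & 34 \end{pmatrix},
\qquad
\mathrm{scores} = \mathrm{Relu}(XW_1)\,W_2
     = \begin{pmatrix} 0 & 12 \\ 0 & 34 \end{pmatrix}
       \begin{pmatrix} 0.1 & 0.2 \\ 0.3 & -0.5 \end{pmatrix}
     = \begin{pmatrix} 3.6 & -6 \\ 10.2 & -17 \end{pmatrix}

\mathrm{loss}_i = \log\sum_j e^{s_{ij}} - s_{i,\,y_i}
  \approx (9.6,\ 27.2) \quad\text{for } y = (1, 1)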
TEST_P(CppGradients, TestMNISTForward2) {
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
// X = data
|
||||
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
int64_t X_dims[] = {3, 2};
|
||||
int num_dims = 2;
|
||||
AbstractTensorHandlePtr X =
|
||||
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
|
||||
|
||||
// W1 = first weights
|
||||
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
|
||||
int64_t dims[] = {2, 2};
|
||||
AbstractTensorHandlePtr W1 =
|
||||
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
|
||||
|
||||
// W2 = second weights
|
||||
float W2_vals[] = {.1f, .2f, .3f, -.5f};
|
||||
AbstractTensorHandlePtr W2 =
|
||||
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
|
||||
|
||||
// y = labels
|
||||
int y_vals[] = {1, 1, 1};
|
||||
int64_t y_dims[] = {3};
|
||||
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
|
||||
AbstractTensorHandlePtr y =
|
||||
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
|
||||
// Run the Forward Pass
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
Status s =
|
||||
RunModel(MNISTForwardModel, ctx.get(),
|
||||
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Verify the Results
|
||||
TF_Tensor* scores_tensor;
|
||||
s = GetValue(outputs[0], &scores_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
float result_data[6] = {0};
|
||||
memcpy(&result_data[0], TF_TensorData(scores_tensor),
|
||||
TF_TensorByteSize(scores_tensor));
|
||||
|
||||
float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
|
||||
float tolerance = 1e-3;
|
||||
for (int j = 0; j < 6; j++) {
|
||||
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
|
||||
}
|
||||
|
||||
TF_Tensor* loss_vals_tensor;
|
||||
s = GetValue(outputs[1], &loss_vals_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
|
||||
TF_TensorByteSize(loss_vals_tensor));
|
||||
float expected_losses[3] = {9.6f, 27.2f, 44.8f};
|
||||
for (int j = 0; j < 3; j++) {
|
||||
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(scores_tensor);
|
||||
TF_DeleteTensor(loss_vals_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMatMulTranspose) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {2, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  GradientRegistry registry;

  // Run the MatMul Op (with the first operand transposed)
  std::vector<AbstractTensorHandle*> outputs(1);

  Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);

  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

  float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  // Release the result handle and tensor, as the other tests do.
  outputs[0]->Unref();
  TF_DeleteTensor(scores_tensor);
}
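
// Hand-check of expected_scores above: MatMulTransposeModel sets
// transpose_a=true, so it computes X^T * W1 with
//   X^T = [[1, 4], [2, 5], [3, 6]] (3x2), W1 = [[1, 2], [3, 4]] (2x2):
//   [[1*1 + 4*3, 1*2 + 4*4],
//    [2*1 + 5*3, 2*2 + 5*4],
//    [3*1 + 6*3, 3*2 + 6*4]] = [[13, 18], [17, 24], [21, 30]],
// which is expected_scores in row-major order.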
TEST_P(CppGradients, TestReluGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
  int64_t X_dims[] = {3, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(X)
   * Y = Relu(X)
   * outputs = tape.gradient(Y, [X])
   */
  std::vector<AbstractTensorHandle*> outputs(1);
  s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dX_tensor;
  s = GetValue(outputs[0], &dX_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[9] = {0};
  memcpy(&result_data[0], TF_TensorData(dX_tensor),
         TF_TensorByteSize(dX_tensor));

  float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f,
                          0.0f, 1.0f, 0.0f, 0.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 9; j++) {
    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dX_tensor);
}
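
// Relu'(x) is 1 for x > 0 and 0 otherwise (TensorFlow's Relu gradient treats
// x == 0 as 0), so the mask for X_vals above is
//   [1, 1, 1, 0, 0, 0, 1, 0, 0],
// exactly the values in expected_dX.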
TEST_P(CppGradients, TestSoftmaxLossGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = scores
  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
  int64_t X_dims[] = {3, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // y = labels
  int y_vals[] = {1, 0, 1};
  int64_t y_dims[] = {3};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(X)
   * tape.watch(labels)
   * loss = SoftmaxLoss(X, labels)
   * outputs = tape.gradient(loss, [X, labels])
   */
  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);

  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dX_tensor;
  s = GetValue(outputs[0], &dX_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[9] = {0};
  memcpy(&result_data[0], TF_TensorData(dX_tensor),
         TF_TensorByteSize(dX_tensor));

  float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f,
                          0.6652f, 0.8437f, -0.8858f, 0.0420f};
  float tolerance = 1e-3;
  for (int j = 0; j < 9; j++) {
    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
  }

  // Only Unref() the first output; the second is a nullptr grad for labels.
  outputs[0]->Unref();
  TF_DeleteTensor(dX_tensor);
}
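
// For sparse softmax cross-entropy, dLoss/dscores = softmax(scores) - one_hot(y).
// Row 0 above: softmax([1, 2, 3]) ~= [0.0900, 0.2447, 0.6652]; label 1 turns
// that into [0.0900, -0.7553, 0.6652], the first row of expected_dX.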
TEST_P(CppGradients, TestMNISTGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(W1)
   * tape.watch(W2)
   * mm = X*W1
   * hidden = Relu(mm)
   * scores = hidden*W2
   * loss = SoftmaxLoss(scores, y)
   * outputs = tape.gradient(loss, [W1, W2])
   */
  std::vector<AbstractTensorHandle*> outputs(3);
  s = RunModel(MNISTGradModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float tolerance = 1e-3;
  TF_Tensor* dW1_tensor;
  s = GetValue(outputs[0], &dW1_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dW1_tensor),
         TF_TensorByteSize(dW1_tensor));

  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};  // dLoss/dW1
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
  }

  TF_Tensor* dW2_tensor;
  s = GetValue(outputs[1], &dW2_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dW2_tensor),
         TF_TensorByteSize(dW2_tensor));

  float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f};  // dLoss/dW2
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  outputs[2]->Unref();
  TF_DeleteTensor(dW1_tensor);
  TF_DeleteTensor(dW2_tensor);
}
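
// Backprop behind the checks above: with hidden = Relu(X*W1) and
// scores = hidden*W2, the standard chain rule gives
//   dLoss/dW2 = hidden^T * (softmax(scores) - one_hot(y))
//   dLoss/dW1 = X^T * (((softmax(scores) - one_hot(y)) * W2^T) . Relu'(X*W1))
// where "." is an elementwise product; expected_dW1/expected_dW2 hold these
// values for the constants in this test.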
TEST_P(CppGradients, TestScalarMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr eta;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    eta.reset(x_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);

  GradientRegistry registry;
  std::vector<AbstractTensorHandle*> outputs(1);
  Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // outputs[0] holds the forward product eta * A; no gradient is taken here.
  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float tolerance = 1e-3;
  float eta_val = 1.5f;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dA_tensor);
}
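
// Mul broadcasts the scalar eta across A, so the loop above checks
// 1.5 * {1, 2, 3, 4} = {1.5, 3.0, 4.5, 6.0} elementwise.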
TEST_P(CppGradients, TestMNIST_Training) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // TODO(amturati): use random initializer for weights instead of
  // constant values.

  // W1 = first weights
  float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Prepare for training
  std::vector<AbstractTensorHandle*> weights;
  weights.push_back(W1.get());
  weights.push_back(W2.get());

  // Set learning rate to 1e-1
  AbstractTensorHandle* learning_rate = nullptr;
  s = TestScalarTensorHandle(ctx.get(), 1e-1f, &learning_rate);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Train
  int num_iters = 10;
  std::vector<AbstractTensorHandle*> mnist_outputs(3);
  std::vector<AbstractTensorHandle*> grads(2);
  for (int i = 0; i < num_iters; i++) {
    // Run Forward Pass
    s = RunModel(MNISTGradModel, ctx.get(),
                 {X.get(), weights[0], weights[1], y.get()},
                 absl::MakeSpan(mnist_outputs),
                 /*use_function=*/!std::get<2>(GetParam()), registry);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();

    // Fill grads
    grads[0] = mnist_outputs[0];
    grads[1] = mnist_outputs[1];

    // Gradient Update
    s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  }

  grads[0]->Unref();          // release W1_grad
  grads[1]->Unref();          // release W2_grad
  mnist_outputs[2]->Unref();  // release loss
  learning_rate->Unref();     // release the learning-rate handle
}
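
// Each iteration above performs one plain SGD step per weight matrix,
//   W <- W + (-lr) * dW,
// which UpdateWeights (mnist_gradients_testutil.cc) implements as a Neg on
// the learning rate followed by per-weight Mul and Add ops.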
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
594
tensorflow/c/eager/mnist_gradients_testutil.cc
Normal file
@ -0,0 +1,594 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

using std::vector;
using tracing::TracingOperation;
// ========================== Tape Ops ==============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs,
           const GradientRegistry& registry) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(
      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
  }
  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
  int num_retvals = 1;
  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}
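
// Minimal usage sketch for these tape ops (illustrative only; assumes a live
// context `ctx`, two same-shaped handles `a` and `b`, and error handling as
// in the callers below):
//
//   Tape tape(/*persistent=*/false);
//   GradientRegistry registry;
//   std::vector<AbstractTensorHandle*> out(1);
//   Status s = Add(ctx, &tape, {a, b}, absl::MakeSpan(out), registry);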
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b,
              const GradientRegistry& registry) {
  AbstractOperationPtr matmul_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
                           /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(matmul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
      matmul_op.get(), "transpose_b", transpose_b, &forward_op));

  int num_retvals = 1;
  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op,
                 tape, registry);
}

// Computes `inputs[0] * inputs[1]` elementwise and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name,
           const GradientRegistry& registry) {
  AbstractOperationPtr mul_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(
      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(mul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));

  int num_retvals = 1;
  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name,
            const GradientRegistry& registry) {
  AbstractOperationPtr relu_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(
      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(relu_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
  int num_retvals = 1;
  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}

// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
    AbstractContext* ctx, Tape* tape,
    absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name,
    const GradientRegistry& registry) {
  AbstractTensorHandle* scores = inputs[0];
  AbstractTensorHandle* labels = inputs[1];

  AbstractOperationPtr sm_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
                           /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(sm_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));

  int num_retvals = 2;  // returns loss values and backprop
  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}
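
// Note: SparseSoftmaxCrossEntropyWithLogits emits two tensors, so callers
// size `outputs` to 2: outputs[0] receives the per-example loss and
// outputs[1] the precomputed backprop gradient.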
//===================== Test Models to run =========================

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  tape->Watch(ToId(inputs[1]));  // Watch y.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
                         registry));  // Compute x+y.
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto add_output : add_outputs) {
    add_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}
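
// The models here new/delete the tape manually; an early return through
// TF_RETURN_IF_ERROR therefore leaks it. A possible tightening (sketch,
// hypothetical restructuring of the code above):
//
//   std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
//   ... use tape.get() everywhere ...
//   return Status::OK();  // tape is freed on every path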
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  tape->Watch(ToId(inputs[1]));  // Watch y.
  vector<AbstractTensorHandle*> mm_outputs(1);
  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
                            "matmul0", /*transpose_a=*/false,
                            /*transpose_b=*/false, registry));  // Compute x*y.

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto mm_output : mm_outputs) {
    mm_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}
// Model to run a 2-layer fully connected net.
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry) {
  /**
   * We will trace a 2-layer fully connected network for an MNIST model:
   *
   * def mnist_forward(X, W1, W2, y_labels):
   *   mm_out_1 = tf.matmul(X, W1)
   *   hidden_layer = tf.nn.relu(mm_out_1)
   *   scores = tf.matmul(hidden_layer, W2)
   *   softmax =
   *       tf.nn.sparse_softmax_cross_entropy_with_logits(scores, y_labels)
   *   return scores, softmax
   *
   * Use this convention for inputs:
   *
   *   inputs = [X, W1, W2, y_labels]
   */
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(W1));  // Watch W1.
  tape->Watch(ToId(W2));  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
                            "matmul0", /*transpose_a=*/false,
                            /*transpose_b=*/false, registry));  // Compute X*W1

  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
                          absl::MakeSpan(temp_outputs), "relu",
                          registry));  // Compute Relu(X*W1)

  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
                            absl::MakeSpan(temp_outputs), "matmul1",
                            /*transpose_a=*/false, /*transpose_b=*/false,
                            registry));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmax_loss", registry));  // Compute SoftmaxLoss(scores, y_labels)

  AbstractTensorHandle* loss_vals = temp_outputs[0];

  outputs[0] = scores;
  outputs[1] = loss_vals;
  delete tape;
  return Status::OK();
}
Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(X));
  tape->Watch(ToId(W1));
  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
                            "matmul0", /*transpose_a=*/true,
                            /*transpose_b=*/false,
                            registry));  // Compute X^T*W1

  outputs[0] = temp_outputs[0];

  delete tape;
  return Status::OK();
}
Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch X
  vector<AbstractTensorHandle*> relu_outputs(1);
  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
                          "relu0", registry));  // Relu(X)

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));

  for (auto relu_output : relu_outputs) {
    relu_output->Unref();
  }

  outputs[0] = out_grads[0];
  delete tape;
  return Status::OK();
}
Status SoftmaxLossGradModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch scores.
  tape->Watch(ToId(inputs[1]));  // Watch labels.
  vector<AbstractTensorHandle*> sm_outputs(2);
  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));

  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}
Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/true);
  tape->Watch(ToId(X));   // Watch X.
  tape->Watch(ToId(W1));  // Watch W1.
  tape->Watch(ToId(W2));  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);
  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
                            "matmul0", /*transpose_a=*/false,
                            /*transpose_b=*/false, registry));  // Compute X*W1

  AbstractTensorHandle* mm = temp_outputs[0];

  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
                          "relu0", registry));

  AbstractTensorHandle* hidden = temp_outputs[0];

  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
                            absl::MakeSpan(temp_outputs), "matmul1",
                            /*transpose_a=*/false, /*transpose_b=*/false,
                            registry));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss", registry));  // Compute SoftmaxLoss(scores, y_labels)

  AbstractTensorHandle* loss = temp_outputs[0];

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(
      tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
                            /*source_tensor_ids=*/{ToId(W1), ToId(W2)},
                            source_tensors_that_are_targets,
                            /*output_gradients=*/{}, &out_grads,
                            /*build_default_zeros_grads=*/false));

  // Only release the second temp output; the first holds the loss values.
  temp_outputs[1]->Unref();

  outputs[0] = out_grads[0];  // dW1
  outputs[1] = out_grads[1];  // dW2
  outputs[2] = loss;

  delete tape;
  return Status::OK();
}
Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  AbstractTensorHandle* eta = inputs[0];
  AbstractTensorHandle* A = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
                         "scalarMul0", registry));  // Compute eta*A

  outputs[0] = temp_outputs[0];

  delete tape;
  return Status::OK();
}

// ============================= End Models ================================
Status UpdateWeights(AbstractContext* ctx,
                     vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update weights one by one using the gradient update rule:
   *
   *   w -= lr * grad[w]
   *
   * NOTE: assumes the learning rate is positive.
   */
  int num_grads = grads.size();
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate the learning rate for gradient descent.
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    update_str = "update_mul_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    // Update the weights
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
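
// Worked example with the constants from TestMNISTGrad/TestMNIST_Training:
// lr = 0.1, dW = {0.0, 3.2, 0.0, 4.8}, W = {-0.01, 0.4, 0.5, -0.2} gives
//   W + (-0.1) * dW = {-0.01, 0.08, 0.5, -0.68}
// after one update.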
AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             vector<AbstractTensorHandle*>* params) {
  tracing::TracingTensorHandle* handle = nullptr;
  for (auto input : inputs) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
        input->DataType(), &handle));
    params->emplace_back(handle);
  }
  return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of which indices in the model's outputs are nullptr in this set.
    // The FunctionDef only outputs the non-null tensors. We later pad the
    // function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
    {
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_DeleteContextOptions(opts);
  return Status::OK();
}
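
// Typical call shape for RunModel, mirroring the tests in
// mnist_gradients_test.cc (sketch; assumes handles x, y and a populated
// registry):
//
//   std::vector<AbstractTensorHandle*> outputs(2);
//   Status s = RunModel(AddGradModel, ctx, {x, y}, absl::MakeSpan(outputs),
//                       /*use_function=*/true, registry);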
146
tensorflow/c/eager/mnist_gradients_testutil.h
Normal file
@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_

#include <functional>
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
// ========================== Tape Ops ==============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs,
           const GradientRegistry& registry);

// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b,
              const GradientRegistry& registry);

// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name,
           const GradientRegistry& registry);

// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name,
            const GradientRegistry& registry);

// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
    AbstractContext* ctx, Tape* tape,
    absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name,
    const GradientRegistry& registry);

// ====================== End Tape Ops ============================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry);

// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry);

// Computes a 2-layer neural network with Softmax Loss.
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry);

// Computes MatMul with the first matrix transposed.
Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry);

// Test model to verify ReluGrad functionality.
Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry);

// Test model to verify SoftmaxGrad functionality.
Status SoftmaxLossGradModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry);

// Test model to verify multi-grad functionality for MNIST.
Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

// Test model to verify the scalar-tensor multiplication op.
Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

// Updates the weights for a neural network, given incoming grads and a
// learning rate.
Status UpdateWeights(AbstractContext* ctx,
                     std::vector<AbstractTensorHandle*>& grads,
                     std::vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate);

AbstractContext* BuildFunction(const char* fn_name);

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             std::vector<AbstractTensorHandle*>* params);

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry);

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

#endif  // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
@ -76,10 +76,26 @@ cc_library(
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

tf_cc_test(
    name = "parallel_device_lib_test",
    srcs = ["parallel_device_lib_test.cc"],
    deps = [
        ":parallel_device_lib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "parallel_device_testlib",
    testonly = 1,
@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@ -118,6 +119,9 @@ class DeviceThread {
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  // Outputs
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // unnecessary allocations on each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
@ -188,6 +192,9 @@ std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  if (TF_GetCode(status_.get()) != TF_OK) {
    TF_SetStatus(status, TF_GetCode(status_.get()),
                 TF_Message(status_.get()));
    // Reset the member `status_` so future op executions (after recovery from
    // the bad `status`) start with an OK status.
    TF_SetStatus(status_.get(), TF_OK, "");
  }
  execution_state_ = ExecutionState::kIdle;
  result = std::move(op_outputs_);
@ -255,18 +262,27 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
      status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
    TFE_Context* context, TF_Status* status,
    absl::Span<const int32_t> values) const {
  // TODO(allenl): We could cache DeviceIDs (keyed by context).
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (int device_index = 0; device_index < underlying_devices_.size();

  if (values.size() != num_underlying_devices()) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "Number of values did not match number of underlying devices.");
    return nullptr;
  }

  for (int device_index = 0; device_index < num_underlying_devices();
       ++device_index) {
    int32_t* device_id = new int32_t;
    *device_id = device_index;
    int32_t* device_value = new int32_t;
    *device_value = values[device_index];
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
            sizeof(int32_t),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<int32_t*>(data);
@ -295,6 +311,16 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
      status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  std::vector<int32_t> ids;
  ids.reserve(num_underlying_devices());
  for (int i = 0; i < num_underlying_devices(); ++i) {
    ids.push_back(i);
  }
  return Vector(context, status, ids);
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        const std::vector<ParallelTensor*>& inputs,
@ -319,21 +345,36 @@ ParallelDevice::Execute(TFE_Context* context,
                                     std::move(device_inputs), attributes,
                                     expected_max_outputs);
  }
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    if (TF_GetCode(status) != TF_OK) return result;
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
@ -61,6 +62,11 @@ class ParallelDevice {
                                                 TFE_TensorHandle* tensor,
                                                 TF_Status* status) const;

  // Construct a parallel tensor consisting of the scalar values from `values`.
  std::unique_ptr<ParallelTensor> Vector(
      TFE_Context* context, TF_Status* status,
      absl::Span<const int32_t> values) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;
@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
  // Reading the uninitialized variables fails; the parallel device should
  // surface the error rather than deadlock.
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

}  // namespace parallel_device
}  // namespace tensorflow
@ -146,13 +146,16 @@ class GradientTape {
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  // When running backward functions, builds zeros-like tensors for
  // incoming grads which are nullptrs, unless `build_default_zeros_grads`
  // is set to false.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);
      std::vector<Gradient*>* result, bool build_default_zeros_grads = true);

  bool IsPersistent() const { return persistent_; }

@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
    gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
    bool build_default_zeros_grads) {
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
        const int64 id = trace.output_tensor_info[i].GetID();
        auto grad_it = gradients.find(id);
        if (grad_it == gradients.end()) {
          auto func_name_it =
              FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
          if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
              func_name_it->second.find(i) != func_name_it->second.end()) {
            out_gradients.push_back(nullptr);
          } else {
            out_gradients.push_back(nullptr);
            zero_indices.push_back(i);
          out_gradients.push_back(nullptr);
          if (build_default_zeros_grads) {
            auto func_name_it =
                FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
            if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
                func_name_it->second.find(i) == func_name_it->second.end()) {
              zero_indices.push_back(i);
            }
          }
        } else {
          any_gradient_nonzero = true;
@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
        }
      }
      std::vector<Gradient*> in_gradients;
      DCHECK(build_default_zeros_grads || zero_indices.empty());
      if (any_gradient_nonzero) {
        for (const auto i : zero_indices) {
          out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
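The new `build_default_zeros_grads` flag lets a caller opt out of materializing zeros-like tensors for missing incoming gradients. A minimal sketch of a call site under that flag; the tape, vspace, and id vectors are assumed to be set up elsewhere, and only the final argument is new:

  // Hypothetical call site inside a Status-returning helper. Passing `false`
  // leaves missing incoming grads as nullptr instead of building
  // zeros-like tensors for them.
  std::vector<Gradient*> result;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, target_ids, source_ids, sources_map,
      /*output_gradients=*/{}, &result,
      /*build_default_zeros_grads=*/false));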
@ -26,6 +26,8 @@ cc_library(
    }),
    deps = [
        ":aws_crypto",
        ":aws_logging",
        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@aws",
@ -45,6 +47,18 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "aws_logging",
    srcs = ["aws_logging.cc"],
    hdrs = ["aws_logging.h"],
    deps = [
        "//tensorflow/c:logging",
        "@aws",
        "@com_google_absl//absl/synchronization",
    ],
    alwayslink = 1,
)

tf_cc_test(
    name = "s3_filesystem_test",
    srcs = [
tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc (new file, 159 lines)
@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"

#include <aws/core/Aws.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>

#include <cstdarg>
#include <cstdio>
#include <sstream>

#include "absl/synchronization/mutex.h"
#include "tensorflow/c/logging.h"

static constexpr char kAWSLoggingTag[] = "AWSLogging";

static const std::map<const std::string, const Aws::Utils::Logging::LogLevel>
    log_levels_string_to_aws = {
        {"off", Aws::Utils::Logging::LogLevel::Off},
        {"fatal", Aws::Utils::Logging::LogLevel::Fatal},
        {"error", Aws::Utils::Logging::LogLevel::Error},
        {"warn", Aws::Utils::Logging::LogLevel::Warn},
        {"info", Aws::Utils::Logging::LogLevel::Info},
        {"debug", Aws::Utils::Logging::LogLevel::Debug},
        {"trace", Aws::Utils::Logging::LogLevel::Trace}};

static const std::map<const int, const Aws::Utils::Logging::LogLevel>
    log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info},
                            {1, Aws::Utils::Logging::LogLevel::Warn},
                            {2, Aws::Utils::Logging::LogLevel::Error},
                            {3, Aws::Utils::Logging::LogLevel::Fatal}};

namespace tf_s3_filesystem {

AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)
    : log_level_(log_level) {}

void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level,
                              const std::string& message) {
  if (message == "Initializing Curl library") return;
  switch (log_level) {
    case Aws::Utils::Logging::LogLevel::Info:
      TF_Log(TF_INFO, message.c_str());
      break;
    case Aws::Utils::Logging::LogLevel::Warn:
      TF_Log(TF_WARNING, message.c_str());
      break;
    case Aws::Utils::Logging::LogLevel::Error:
      TF_Log(TF_ERROR, message.c_str());
      break;
    case Aws::Utils::Logging::LogLevel::Fatal:
      TF_Log(TF_FATAL, message.c_str());
      break;
    default:
      // This matches the remaining levels, DEBUG and TRACE.
      TF_Log(TF_INFO, message.c_str());
      break;
  }
}

void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
                       const char* format, ...) {
  char buffer[256];
  va_list args;
  va_start(args, format);
  vsnprintf(buffer, 256, format, args);
  va_end(args);
  LogMessage(log_level, buffer);
}

void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level,
                             const char* tag,
                             const Aws::OStringStream& message_stream) {
  LogMessage(log_level, message_stream.rdbuf()->str().c_str());
}

void AWSLogSystem::Flush() { return; }

static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) {
  // Converts the TF log levels INFO, WARNING, ERROR and FATAL to the
  // corresponding AWS enum values.
  if (log_levels_tf_to_aws.find(level) != log_levels_tf_to_aws.end()) {
    return log_levels_tf_to_aws.at(level);
  } else {
    // Default to Fatal.
    return Aws::Utils::Logging::LogLevel::Fatal;
  }
}

static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() {
  // Defaults to the FATAL log level for the AWS SDK, because the SDK logs
  // many routine TensorFlow operations as errors. For example, checking
  // whether a file exists can log an SDK error if the file is absent, and
  // reading a file to the end yields the InvalidRange exception that
  // TensorFlow expects but the SDK reports as an error. Surfacing these
  // confuses users, hence the conservative default.
  Aws::Utils::Logging::LogLevel log_level =
      Aws::Utils::Logging::LogLevel::Fatal;

  const char* aws_env_var_val = getenv("AWS_LOG_LEVEL");
  if (aws_env_var_val != nullptr) {
    std::string maybe_integer_str(aws_env_var_val, strlen(aws_env_var_val));
    std::istringstream ss(maybe_integer_str);
    int level;
    ss >> level;
    if (ss.fail()) {
      // Not a number, so interpret the value as a level name.
      std::string level_str = maybe_integer_str;
      if (log_levels_string_to_aws.find(level_str) !=
          log_levels_string_to_aws.end()) {
        log_level = log_levels_string_to_aws.at(level_str);
      }
    } else {
      // Backwards compatibility: the value is a valid number, but it follows
      // the standard TensorFlow log levels, so convert it to the AWS SDK
      // logging level.
      log_level = TfLogLevelToAwsLogLevel(level);
    }
  }
  return log_level;
}

static bool initialized = false;
ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit);
void AWSLogSystem::InitializeAWSLogging() {
  absl::MutexLock l(&s3_logging_mutex);
  if (!initialized) {
    Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared<AWSLogSystem>(
        kAWSLoggingTag, ParseAwsLogLevelFromEnv()));
    initialized = true;
    return;
  }
}

void AWSLogSystem::ShutdownAWSLogging() {
  absl::MutexLock l(&s3_logging_mutex);
  if (initialized) {
    Aws::Utils::Logging::ShutdownAWSLogging();
    initialized = false;
    return;
  }
}

}  // namespace tf_s3_filesystem
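AWS_LOG_LEVEL accepts either a TensorFlow-style integer (0-3) or an AWS level name. A minimal sketch of raising SDK verbosity before the plugin installs its logger; the main function here is illustrative, only the environment variable is defined by the file above:

#include <cstdlib>

int main() {
  // "debug" is matched against log_levels_string_to_aws; an integer such as
  // "0" would instead go through TfLogLevelToAwsLogLevel. Unrecognized
  // values fall back to the Fatal default.
  setenv("AWS_LOG_LEVEL", "debug", /*overwrite=*/1);
  // ... use the S3 filesystem; GetS3Client() picks the level up via
  // ParseAwsLogLevelFromEnv() when it calls InitializeAWSLogging().
  return 0;
}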
@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_

#include <aws/core/utils/logging/LogLevel.h>
#include <aws/core/utils/logging/LogSystemInterface.h>

#include <atomic>
#include <string>

namespace tf_s3_filesystem {

class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface {
 public:
  static void InitializeAWSLogging();
  static void ShutdownAWSLogging();

  explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level);
  virtual ~AWSLogSystem() = default;

  // Gets the currently configured log level.
  Aws::Utils::Logging::LogLevel GetLogLevel(void) const override {
    return log_level_;
  }

  // Sets a new log level; the change takes effect immediately.
  void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) {
    log_level_.store(log_level);
  }

  // Does a printf-style output to ProcessFormattedStatement. Don't use this,
  // it's unsafe. See LogStream.
  void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
           const char* format, ...) override;

  // Writes the stream to ProcessFormattedStatement.
  void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag,
                 const Aws::OStringStream& messageStream) override;

  // Flushes the buffered messages if the logger supports buffering.
  void Flush() override;

 private:
  void LogMessage(Aws::Utils::Logging::LogLevel log_level,
                  const std::string& message);
  std::atomic<Aws::Utils::Logging::LogLevel> log_level_;
};

}  // namespace tf_s3_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
@ -38,6 +38,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for S3 environments.
@ -186,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
  absl::MutexLock l(&s3_file->initialization_lock);

  if (s3_file->s3_client.get() == nullptr) {
    tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging();

    Aws::SDKOptions options;
    options.cryptoOptions.sha256Factory_create_fn = []() {
      return Aws::MakeShared<tf_s3_filesystem::AWSSHA256Factory>(
@ -250,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) {
    delete s3_client;
    Aws::SDKOptions options;
    Aws::ShutdownAPI(options);
    tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
  }
}

@ -281,6 +286,7 @@ void Cleanup(TF_RandomAccessFile* file) {

static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
                            char* buffer, TF_Status* status) {
  TF_VLog(3, "ReadFile using S3Client\n");
  Aws::S3::Model::GetObjectRequest get_object_request;
  get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
  Aws::String bytes =
@ -306,12 +312,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,

static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
                                     char* buffer, TF_Status* status) {
  TF_VLog(3, "Using TransferManager\n");
  auto create_download_stream = [&]() {
    return Aws::New<TFS3UnderlyingStream>(
        "S3ReadStream",
        Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
            "S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
  };
  TF_VLog(3, "Created stream to read with transferManager\n");
  auto handle = s3_file->transfer_manager->DownloadFile(
      s3_file->bucket, s3_file->object, offset, n, create_download_stream);
  handle->WaitUntilFinished();
@ -322,6 +330,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
             Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
         retries++ < kDownloadRetries) {
    // Only failed parts will be downloaded again.
    TF_VLog(
        1,
        "Retrying read of s3://%s/%s after failure. Current retry count: %u\n",
        s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
    s3_file->transfer_manager->RetryDownload(handle);
    handle->WaitUntilFinished();
  }
@ -341,6 +353,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n",
          s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n);
  if (s3_file->use_multi_part_download)
    return ReadS3TransferManager(s3_file, offset, n, buffer, status);
  else
@ -416,6 +430,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(),
          s3_file->object.c_str());
  auto position = static_cast<int64_t>(s3_file->outfile->tellp());
  auto handle = s3_file->transfer_manager->UploadFile(
      s3_file->outfile, s3_file->bucket, s3_file->object,
@ -426,6 +442,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
  while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
         retries++ < kUploadRetries) {
    // If multipart upload was used, only the failed parts will be re-sent.
    TF_VLog(1,
            "Retrying upload of s3://%s/%s after failure. Current retry count: "
            "%u\n",
            s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
    s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
    handle->WaitUntilFinished();
  }
@ -613,6 +633,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,

void Stat(const TF_Filesystem* filesystem, const char* path,
          TF_FileStatistics* stats, TF_Status* status) {
  TF_VLog(1, "Stat on path: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -737,6 +758,8 @@ static void SimpleCopyFile(const Aws::String& source,
                           const Aws::String& bucket_dst,
                           const Aws::String& object_dst, S3File* s3_file,
                           TF_Status* status) {
  TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", source.c_str(),
          bucket_dst.c_str(), object_dst.c_str());
  Aws::S3::Model::CopyObjectRequest copy_object_request;
  copy_object_request.WithCopySource(source)
      .WithBucket(bucket_dst)
@ -801,6 +824,8 @@ static void MultiPartCopy(const Aws::String& source,
                          const Aws::String& object_dst, const size_t num_parts,
                          const uint64_t file_size, S3File* s3_file,
                          TF_Status* status) {
  TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", source.c_str(),
          bucket_dst.c_str(), object_dst.c_str());
  Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
  create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);

@ -827,6 +852,8 @@ static void MultiPartCopy(const Aws::String& source,
  auto chunk_size =
      s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];

  TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(),
          num_parts, chunk_size);
  size_t retries = 0;
  while (retries++ < 3) {
    // Queue up parts.
@ -891,6 +918,9 @@ static void MultiPartCopy(const Aws::String& source,
            status);
      } else {
        // Retry.
        TF_Log(TF_ERROR,
               "Retrying failed copy of part %u due to an error with S3\n",
               part_number);
        num_finished_parts--;
      }
    }
@ -967,6 +997,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,

void DeleteFile(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
  TF_VLog(1, "DeleteFile: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -985,6 +1016,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,

void CreateDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  TF_VLog(1, "CreateDir: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -1026,6 +1058,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,

void DeleteDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  TF_VLog(1, "DeleteDir: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -1060,6 +1093,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,

void RenameFile(const TF_Filesystem* filesystem, const char* src,
                const char* dst, TF_Status* status) {
  TF_VLog(1, "RenameFile from: %s to %s\n", src, dst);
  Aws::String bucket_src, object_src;
  ParseS3Path(src, false, &bucket_src, &object_src, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -1120,6 +1154,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,

int GetChildren(const TF_Filesystem* filesystem, const char* path,
                char*** entries, TF_Status* status) {
  TF_VLog(1, "GetChildren for path: %s\n", path);
  Aws::String bucket, prefix;
  ParseS3Path(path, true, &bucket, &prefix, status);
  if (TF_GetCode(status) != TF_OK) return -1;
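MultiPartCopy splits the copy into num_parts pieces of chunk_size bytes, as the new log line records. A small sketch of the arithmetic implied there; this helper is illustrative and not part of the plugin:

#include <cstdint>

// How many parts a multipart copy needs for a file of `file_size` bytes when
// each part carries at most `chunk_size` bytes. The final part may be short.
static size_t NumParts(uint64_t file_size, uint64_t chunk_size) {
  return static_cast<size_t>((file_size + chunk_size - 1) / chunk_size);
}
// e.g. NumParts(100 << 20, 32 << 20) == 4: three full 32 MiB parts plus a
// 4 MiB tail.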
@ -3,6 +3,24 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "array_grad",
    srcs = ["array_grad.cc"],
    hdrs = [
        "array_grad.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/c/eager:gradients",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)

cc_library(
    name = "math_grad",
    srcs = ["math_grad.cc"],
@ -19,6 +37,28 @@ cc_library(
        "//tensorflow/c/eager:gradients",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)

cc_library(
    name = "nn_grad",
    srcs = ["nn_grad.cc"],
    hdrs = [
        "nn_grad.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/c/eager:gradients",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)
tensorflow/c/experimental/gradients/array_grad.cc (new file, 48 lines)
@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"

namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(grad_inputs.size(), nullptr);
    for (int i = 0; i < grad_inputs.size(); i++) {
      auto grad_input = grad_inputs[i];
      // TODO(srbs): Should we add a copy constructor to AbstractTensorHandle
      // that takes care of this similar to `Tensor`?
      if (grad_input) {
        grad_input->Ref();
      }
      (*grad_outputs)[i] = grad_input;
    }
    return Status::OK();
  }
  ~IdentityNGradientFunction() override {}
};
}  // namespace

BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
  auto gradient_function = new IdentityNGradientFunction;
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/experimental/gradients/array_grad.h (new file, 26 lines)
@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
@ -15,13 +15,17 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/math_grad.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"

using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike;

namespace tensorflow {
namespace gradients {
@ -29,20 +33,23 @@ namespace {

class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_inputs,
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);
    vector<AbstractTensorHandle*> identity_outputs(1);
    // TODO(b/145674566): Handle name unification in tracing code.
    // TODO(b/161805092): Support broadcasting.

    std::string name = "Identity_A";
    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                     absl::MakeSpan(identity_outputs),
                                     "Identity0"));
                                     name.c_str()));
    (*grad_outputs)[0] = identity_outputs[0];

    name = "Identity_B";
    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                     absl::MakeSpan(identity_outputs),
                                     "Identity1"));
                                     name.c_str()));
    (*grad_outputs)[1] = identity_outputs[0];
    return Status::OK();
  }
@ -54,16 +61,18 @@ class ExpGradientFunction : public GradientFunction {
  explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
    exp->Ref();
  }
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_inputs,
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    vector<AbstractTensorHandle*> conj_outputs(1);
    TF_RETURN_IF_ERROR(
        Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
    std::string name = "Conj_Exp_Grad";
    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
                            absl::MakeSpan(conj_outputs), name.c_str()));
    AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
    grad_outputs->resize(1);

    name = "Mul_Exp_Grad";
    TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
                           absl::MakeSpan(*grad_outputs), "ExpGradMul"));
                           absl::MakeSpan(*grad_outputs), name.c_str()));
    return Status::OK();
  }
  ~ExpGradientFunction() override {}
@ -72,14 +81,142 @@ class ExpGradientFunction : public GradientFunction {
  AbstractTensorHandlePtr exp_;
};

class MatMulGradientFunction : public GradientFunction {
 public:
  explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
                                  AttrBuilder f_attrs)
      : forward_inputs(f_inputs), forward_attrs(f_attrs) {}

  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    /* Given upstream grad U and a matmul op A*B, the gradients are:
     *
     *    dA = U * B.T
     *    dB = A.T * U
     *
     * where A.T means `transpose(A)`
     */
    AbstractTensorHandle* upstream_grad = grad_inputs[0];
    grad_outputs->resize(2);

    // Get the transpose attrs.
    bool t_a;
    forward_attrs.Get("transpose_a", &t_a);

    bool t_b;
    forward_attrs.Get("transpose_b", &t_b);

    // Conj each input.
    vector<AbstractTensorHandle*> conj_outputs(1);
    std::string name = "Conj_A_MatMul_Grad";
    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
                            absl::MakeSpan(conj_outputs), name.c_str()));

    AbstractTensorHandle* A = conj_outputs[0];

    name = "Conj_B_MatMul_Grad";
    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
                            absl::MakeSpan(conj_outputs), name.c_str()));

    AbstractTensorHandle* B = conj_outputs[0];

    // Compute the gradients, branching on the transpose attrs.
    vector<AbstractTensorHandle*> matmul_A_outputs(1);
    vector<AbstractTensorHandle*> matmul_B_outputs(1);
    std::string name_grad_A = "MatMul_Grad_A";
    std::string name_grad_B = "MatMul_Grad_B";
    if (!t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
                                absl::MakeSpan(matmul_A_outputs),
                                name_grad_A.c_str(),
                                /*transpose_a=*/false,
                                /*transpose_b=*/true));

      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
                                absl::MakeSpan(matmul_B_outputs),
                                name_grad_B.c_str(),
                                /*transpose_a=*/true,
                                /*transpose_b=*/false));
    } else if (!t_a && t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
                                absl::MakeSpan(matmul_A_outputs),
                                name_grad_A.c_str(),
                                /*transpose_a=*/false,
                                /*transpose_b=*/false));

      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
                                absl::MakeSpan(matmul_B_outputs),
                                name_grad_B.c_str(),
                                /*transpose_a=*/true,
                                /*transpose_b=*/false));

    } else if (t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
                                absl::MakeSpan(matmul_A_outputs),
                                name_grad_A.c_str(),
                                /*transpose_a=*/false,
                                /*transpose_b=*/true));

      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
                                absl::MakeSpan(matmul_B_outputs),
                                name_grad_B.c_str(),
                                /*transpose_a=*/false,
                                /*transpose_b=*/false));
    } else {  // t_a && t_b
      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
                                absl::MakeSpan(matmul_A_outputs),
                                name_grad_A.c_str(),
                                /*transpose_a=*/true,
                                /*transpose_b=*/true));

      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
                                absl::MakeSpan(matmul_B_outputs),
                                name_grad_B.c_str(),
                                /*transpose_a=*/true,
                                /*transpose_b=*/true));
    }

    // Gradient for A.
    (*grad_outputs)[0] = matmul_A_outputs[0];

    // Gradient for B.
    (*grad_outputs)[1] = matmul_B_outputs[0];
    return Status::OK();
  }
  ~MatMulGradientFunction() override {}

 private:
  vector<AbstractTensorHandle*> forward_inputs;
  AttrBuilder forward_attrs;
};

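The comment in MatMulGradientFunction states only the untransposed case; the four branches implement the general rule. Reading the MatMul calls back into math (with upstream gradient U; this summary is derived from the branches above, it is not stated in the commit):

\begin{aligned}
C = AB &: & dA &= U B^{\top}, & dB &= A^{\top} U \\
C = AB^{\top} &: & dA &= U B, & dB &= U^{\top} A \\
C = A^{\top} B &: & dA &= B U^{\top}, & dB &= A U \\
C = A^{\top} B^{\top} &: & dA &= B^{\top} U^{\top}, & dB &= U^{\top} A^{\top}
\end{aligned}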
}  // namespace

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
  auto gradient_function = new AddGradientFunction;
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  return new ExpGradientFunction(op.outputs[0]);
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
  auto gradient_function = new ExpGradientFunction(op.outputs[0]);
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
  auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

}  // namespace gradients
@ -19,9 +19,10 @@ limitations under the License.

namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
tensorflow/c/experimental/gradients/nn_grad.cc (new file, 111 lines)
@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"

#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"

using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
using tensorflow::ops::ZerosLike;

namespace tensorflow {
namespace gradients {
namespace {

class ReluGradientFunction : public GradientFunction {
 public:
  explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs(f_outputs) {}

  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    AbstractTensorHandle* upstream_grad = grad_inputs[0];
    AbstractTensorHandle* activations = forward_outputs[0];
    grad_outputs->resize(1);
    vector<AbstractTensorHandle*> relugrad_outputs(1);

    // Compute the gradient.
    std::string name = "relu_grad";

    TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
                                absl::MakeSpan(relugrad_outputs),
                                name.c_str()));
    (*grad_outputs)[0] = relugrad_outputs[0];

    return Status::OK();
  }
  ~ReluGradientFunction() override {}

 private:
  vector<AbstractTensorHandle*> forward_outputs;
};

class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
 public:
  explicit SparseSoftmaxCrossEntropyLossGradientFunction(
      vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs(f_outputs) {}

  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);

    // Gradient for the softmax input.
    std::string name = "Mul_Softmax_Grad";
    vector<AbstractTensorHandle*> mul_outputs(1);
    TF_RETURN_IF_ERROR(
        ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]},
                 absl::MakeSpan(mul_outputs),
                 name.c_str()));  // upstream_grad * local softmax grad
    (*grad_outputs)[0] = mul_outputs[0];

    // The gradient for the labels is null.
    (*grad_outputs)[1] = nullptr;

    return Status::OK();
  }
  ~SparseSoftmaxCrossEntropyLossGradientFunction() override {}

 private:
  vector<AbstractTensorHandle*> forward_outputs;
};

}  // namespace

BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
  auto gradient_function = new ReluGradientFunction(op.outputs);
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
    const ForwardOperation& op) {
  auto gradient_function =
      new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs);
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/experimental/gradients/nn_grad.h (new file, 28 lines)
@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_

#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
    const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
@ -15,7 +15,6 @@ cc_library(
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
@ -36,12 +35,30 @@ cc_library(
        "//tensorflow:internal",
    ],
    deps = [
        ":array_ops",
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
    ],
)

cc_library(
    name = "nn_ops",
    srcs = [
        "nn_ops.cc",
    ],
    hdrs = [
        "nn_ops.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
    ],
@ -19,7 +19,7 @@ limitations under the License.

namespace tensorflow {
namespace ops {
// Creates an Identity op.

Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
@ -35,5 +35,19 @@ Status Identity(AbstractContext* ctx,
  return identity_op->Execute(outputs, &num_retvals);
}

Status ZerosLike(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr z_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
  int num_retvals = 1;
  return z_op->Execute(outputs, &num_retvals);
}

}  // namespace ops
}  // namespace tensorflow
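These op wrappers all follow one shape: allocate an AbstractOperationPtr, Reset with the op name, set a trace name when tracing, add inputs, then Execute into a caller-provided span. A minimal sketch of invoking the new ZerosLike wrapper; the context `ctx` and input handle `x` are assumed to exist, and the fragment belongs inside a Status-returning helper:

  // Hypothetical call site; ZerosLike produces one output with the same
  // shape and dtype as x.
  std::vector<AbstractTensorHandle*> outputs(1);
  TF_RETURN_IF_ERROR(
      tensorflow::ops::ZerosLike(ctx, {x}, absl::MakeSpan(outputs), "zeros"));
  AbstractTensorHandle* zeros = outputs[0];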
@ -22,9 +22,15 @@ limitations under the License.

namespace tensorflow {
namespace ops {

Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status ZerosLike(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

@ -51,5 +51,60 @@ Status Conj(AbstractContext* ctx,
  return Status::OK();
}

Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a = false, bool transpose_b = false) {
  AbstractOperationPtr matmul_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(matmul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));

  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr neg_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
  if (isa<TracingOperation>(neg_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));

  int num_retvals = 1;
  return neg_op->Execute(outputs, &num_retvals);
}

}  // namespace ops
}  // namespace tensorflow
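Adding further binary ops under this pattern is mechanical: swap the op name passed to Reset and keep the rest of the builder sequence. A hedged sketch of what a Sub wrapper could look like; Sub is a standard TF op, but this wrapper is illustrative and not part of the commit:

Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  // Same builder sequence as Add/Neg above, with the "Sub" op name.
  AbstractOperationPtr sub_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(sub_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(sub_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));  // minuend
  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));  // subtrahend
  int num_retvals = 1;
  return sub_op->Execute(outputs, &num_retvals);
}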
@ -25,6 +25,15 @@ Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
Status Conj(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b);
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

tensorflow/c/experimental/ops/nn_ops.cc (new file, 67 lines)
@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/nn_ops.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace ops {

// Softmax loss given scores and labels, used by the SoftMaxLossGradient.
Status SparseSoftmaxCrossEntropyLoss(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
                                       /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0]));  // input scores
  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1]));  // labels

  // Outputs will contain: [loss_vals, gradients].
  int num_retvals = 2;
  TF_RETURN_IF_ERROR(sm_loss_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

// Computes the Relu gradient given input features.
Status ReluGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr relugrad_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(relugrad_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
                           ->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0]));  // upstream grads
  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1]));  // relu inputs

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(relugrad_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow
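Because SparseSoftmaxCrossEntropyWithLogits returns both the per-example loss and the backprop term, callers must size the output span for two results. A minimal call-site sketch; `ctx`, `scores` and `labels` are assumed to come from elsewhere, and the fragment belongs inside a Status-returning helper:

  // Slot 0 receives the per-example loss; slot 1 the softmax gradients that
  // the nn_grad gradient function above later consumes.
  std::vector<AbstractTensorHandle*> outs(2);
  TF_RETURN_IF_ERROR(tensorflow::ops::SparseSoftmaxCrossEntropyLoss(
      ctx, {scores, labels}, absl::MakeSpan(outs), "sm_loss"));
  AbstractTensorHandle* loss = outs[0];
  AbstractTensorHandle* backprop = outs[1];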
tensorflow/c/experimental/ops/nn_ops.h (new file, 37 lines)
@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_

#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace ops {

Status SparseSoftmaxCrossEntropyLoss(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status ReluGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
@ -44,7 +44,9 @@ cc_library(
    ],
    deps = [
        ":concrete_function",
        ":signature_def_function",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

@ -70,6 +72,26 @@ cc_library(
    ],
)

cc_library(
    name = "signature_def_function",
    hdrs = [
        "signature_def_function.h",
    ],
    deps = [
        ":signature_def_function_metadata",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "signature_def_function_metadata",
    hdrs = [
        "signature_def_function_metadata.h",
    ],
)

cc_library(
    name = "test_utils",
    testonly = True,
@ -115,6 +137,7 @@ cc_library(
        ":concrete_function",
        ":saved_model_api",
        ":saved_model_utils",
        ":signature_def_function",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
@ -206,13 +229,13 @@ tf_cc_test(
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
    ],
)

@ -26,10 +26,14 @@ limitations under the License.

namespace tensorflow {

// Note that ConcreteFunctions's lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// ConcreteFunctions correspond to an instance of a tf.function with a known set
// of inputs (either through get_concrete_function) or an input_signature.
// ConcreteFunction attempts to preserve the user-facing semantics of the
// tf.function python API and can take a limited set of types as arguments
// (to be modeled in tensorflow::Value), not just Tensors.
// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they
// are loaded from, since they retain pointers to the TensorHandles owned by the
// SavedModel, and the FunctionDef of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
@ -37,10 +37,11 @@ static const char kNoSharingResourceID[] =

Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                           DataType dtype, TensorShape shape,
                                           const char* raw_device_name,
                                           ImmediateTensorHandlePtr* handle) {
  ImmediateOpPtr varhandle_op(ctx->CreateOperation());

  TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
  TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name));
  TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));

  // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
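Threading raw_device_name through to Reset lets the VarHandleOp be pinned to the device recorded in a SavedVariable. A hedged sketch of a pinned call; the device string here is illustrative, and passing nullptr (as the updated tests below do) keeps the old placement behavior:

  // Hypothetical call site inside a Status-returning helper.
  ImmediateTensorHandlePtr handle;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, {2, 2},
      "/job:localhost/replica:0/task:0/device:CPU:0", &handle));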
@ -31,6 +31,7 @@ namespace internal {
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                           DataType dtype, TensorShape shape,
                                           const char* raw_device_name,
                                           ImmediateTensorHandlePtr* handle);

// Executes an AssignVariableOp using `ctx`, assigning the variable associated
@ -55,7 +55,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr handle;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &handle));
      context(), DT_FLOAT, {}, nullptr, &handle));
  // The created TensorHandle should be a DT_Resource
  EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}
@ -65,7 +65,7 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr handle;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &handle));
      context(), DT_FLOAT, {}, nullptr, &handle));

  // Destroy the variable
  TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
@ -76,7 +76,7 @@ TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr variable;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &variable));
      context(), DT_FLOAT, {}, nullptr, &variable));

  // Create a Scalar float TensorHandle with value 42, and assign it to
  // the variable.
@ -65,10 +65,11 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
                                     DataType dtype, TensorShape shape,
                                     absl::optional<std::string> name,
                                     const char* raw_device_name,
                                     std::unique_ptr<Variable>* output) {
  ImmediateTensorHandlePtr handle;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, dtype, shape, &handle));
      ctx, dtype, shape, raw_device_name, &handle));

  output->reset(
      new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
@ -37,6 +37,7 @@ class Variable : public TensorHandleConvertible {
  static Status CreateUninitialized(ImmediateExecutionContext* ctx,
                                    DataType dtype, TensorShape shape,
                                    absl::optional<std::string> name,
                                    const char* raw_device_name,
                                    std::unique_ptr<Variable>* output);

  // The dtype of the underlying variable.
@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
@ -39,11 +40,11 @@ class SavedModelAPI {
  virtual Status GetFunction(const std::string& function_path,
                             ConcreteFunction** function) = 0;

  // Retrieve a function from a SavedModel, using the key of the
  // Retrieve a SignatureDefFunction from a SavedModel, using the key of the
  // SignatureDef map:
  // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
  virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
                                         ConcreteFunction** function) = 0;
                                         SignatureDefFunction** function) = 0;

  virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
@ -122,9 +122,9 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
  tensorflow::TensorShape shape(variable.shape());
  tensorflow::DataType dtype = variable.dtype();

  TF_RETURN_IF_ERROR(
      Variable::CreateUninitialized(ctx, dtype, shape, name, output));

  TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
      ctx, dtype, shape, name,
      variable.device().empty() ? nullptr : variable.device().c_str(), output));
  return Status();
}
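
The interesting line in this hunk is the ternary: an unset `device` field in the SavedVariable proto reads back as an empty string, which must map to nullptr so the runtime keeps its default placement, while a recorded device name is forwarded verbatim. Spelled out as a standalone sketch (names taken from the surrounding code):

// Empty proto string means "no device recorded at save time" -> nullptr
// (default placement); otherwise forward the saved device name as-is.
const char* raw_device_name =
    variable.device().empty() ? nullptr : variable.device().c_str();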
@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -38,9 +39,15 @@ namespace {
class SavedVariableLoadingTest : public ::testing::TestWithParam<
                                     std::tuple<DataType, std::vector<int64>>> {
 public:
  SavedVariableLoadingTest()
      : device_mgr_(testing::CreateTestingDeviceMgr()),
        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
  SavedVariableLoadingTest() {
    SessionOptions options;
    options.config.mutable_device_count()->insert({"CPU", 3});
    std::vector<std::unique_ptr<Device>> devices;
    TF_CHECK_OK(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices));
    device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
    ctx_ = testing::CreateTestingEagerContext(device_mgr_.get());
  }

  EagerContext* context() { return ctx_.get(); }

@ -67,6 +74,39 @@ TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
  EXPECT_EQ(var->shape(), shape);
}

// Verify that a device specified in the SavedVariable is kept.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithDevice) {
  auto& test_params = GetParam();
  DataType dtype = std::get<0>(test_params);
  TensorShape shape(std::get<1>(test_params));

  SavedVariable saved_variable;
  saved_variable.set_dtype(dtype);
  saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
  shape.AsProto(saved_variable.mutable_shape());

  std::unique_ptr<Variable> var;
  TF_ASSERT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
  EXPECT_EQ(down_cast<TensorHandle*>(var->handle())->resource_device()->name(),
            "/job:localhost/replica:0/task:0/device:CPU:1");
}

// Verify load failure if a non-existing device is specified.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithInvalidDevice) {
  auto& test_params = GetParam();
  DataType dtype = std::get<0>(test_params);
  TensorShape shape(std::get<1>(test_params));

  SavedVariable saved_variable;
  saved_variable.set_dtype(dtype);
  saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:99");
  shape.AsProto(saved_variable.mutable_shape());

  std::unique_ptr<Variable> var;
  ASSERT_NE(Status::OK(),
            internal::LoadSavedVariable(context(), saved_variable, &var));
}

// Assigning and reading values should yield
// consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
@ -79,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
  Status status;
  std::unique_ptr<Variable> var;
  TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
                                             absl::nullopt, &var));
                                             absl::nullopt, nullptr, &var));

  // Create a TensorHandle
  ImmediateTensorHandlePtr expected_handle =
@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"

namespace tensorflow {

// See tensorflow/cc/experimental/saved_model/public/signature_def_function.h
// for SignatureDefFunction's intended user-facing semantics.
// This class is the "implementation" C++ part of the C++/C/C++ sandwich for
// a SignatureDefFunction.
// Note(bmzhao): Implementation-wise, SignatureDefFunctions are always saved as
// a "BareConcreteFunction", w/o a FunctionSpec, rather than a SavedFunction:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/saved_object_graph.proto#L60
// Additionally they are guaranteed to be children of the .signatures attribute
// of the root object, where the child object "name" is the signature_def key:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/python/saved_model/signature_serialization.py#L181-L230
// One of the critical requirements of SignatureDef functions is that their
// inputs and outputs are "named". For example, a `.signatures` function:
// a. Requires users to pass: kwargs of all inputs:
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L119-L126
// b. Returns a dictionary of named outputs.
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L153-L161
// Since SignatureDefFunctions do not have FunctionSpecs, but guarantee the
// dictionary of inputs/outputs, we can parse these dictionaries' keys to obtain
// the input/output names of the SignatureDef:
// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/meta_graph.proto#L318-L321
class SignatureDefFunction {
 public:
  virtual ~SignatureDefFunction() = default;

  // Creates a "Call" Op used to execute the function.
  virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
                            ImmediateOpPtr* out) const = 0;

  virtual const SignatureDefFunctionMetadata& GetFunctionMetadata() const = 0;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_
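
To make the intended call flow concrete, here is a hedged sketch of how a caller might drive this interface once a concrete implementation lands; `func` and `inputs` are assumed to be in scope, the single expected output is illustrative, and GetSignatureDefFunction still returns Unimplemented at this point in the series:

// Sketch only: build the call op, then execute it for its named outputs.
ImmediateOpPtr call_op;
TF_RETURN_IF_ERROR(func->MakeCallOp(inputs, &call_op));
AbstractTensorHandle* retvals[1];
int num_retvals = 1;
TF_RETURN_IF_ERROR(call_op->Execute(absl::MakeSpan(retvals, 1), &num_retvals));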
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_

namespace tensorflow {

class SignatureDefFunctionMetadata {
  // TODO(bmzhao): Fill in with fields as necessary
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -305,7 +306,7 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
}

Status TFSavedModelAPI::GetSignatureDefFunction(
    const std::string& signature_def_key, ConcreteFunction** function) {
    const std::string& signature_def_key, SignatureDefFunction** function) {
  // TODO(bmzhao): Add support for retrieving a signaturedef function.
  return errors::Unimplemented(
      "Retrieving SignatureDef functions is unimplemented currently");
@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h"

@ -55,7 +56,7 @@ class TFSavedModelAPI : public SavedModelAPI {
                       ConcreteFunction** function) override;

  Status GetSignatureDefFunction(const std::string& signature_def_key,
                                 ConcreteFunction** function) override;
                                 SignatureDefFunction** function) override;

  static Status Load(
      const std::string& directory,
@ -142,6 +142,8 @@ cc_library(
        ":concrete_function_list_type",
        ":concrete_function_type",
        ":saved_model_api_type",
        ":signature_def_function",
        ":signature_def_function_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_internal",
@ -165,6 +167,77 @@ cc_library(
    ],
)

cc_library(
    name = "signature_def_function",
    srcs = [
        "signature_def_function.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:signature_def_function.h",
    ],
    copts = tf_copts(),
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":signature_def_function_metadata",
        ":signature_def_function_metadata_type",
        ":signature_def_function_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:tfe_op_internal",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "signature_def_function_type",
    hdrs = [
        "signature_def_function_type.h",
    ],
    deps = [
        "//tensorflow/c:conversion_macros",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function",
    ],
)

cc_library(
    name = "signature_def_function_metadata",
    srcs = [
        "signature_def_function_metadata.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata.h",
    ],
    copts = tf_copts(),
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":signature_def_function_metadata_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
    ],
)

cc_library(
    name = "signature_def_function_metadata_type",
    hdrs = [
        "signature_def_function_metadata_type.h",
    ],
    deps = [
        "//tensorflow/c:conversion_macros",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
    ],
)

tf_cc_test(
    name = "saved_model_api_test",
    size = "small",
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -106,9 +107,11 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
  return tensorflow::wrap(result);
}

TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
    TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
  tensorflow::ConcreteFunction* result = nullptr;
TF_CAPI_EXPORT extern TF_SignatureDefFunction*
TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
                                     const char* signature_def_key,
                                     TF_Status* status) {
  tensorflow::SignatureDefFunction* result = nullptr;
  tensorflow::Status get_function_status =
      tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
                                                         &result);
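
A hedged sketch of the C-side call pattern for the function above; `model` is assumed to come from TF_LoadSavedModel, and "serving_default" is a conventional (not guaranteed) signature key. Since the core lookup still returns Unimplemented in this change, the error branch is the one that currently fires:

// Sketch only.
TF_Status* status = TF_NewStatus();
TF_SignatureDefFunction* sig_fn =
    TF_GetSavedModelSignatureDefFunction(model, "serving_default", status);
if (TF_GetCode(status) != TF_OK) {
  fprintf(stderr, "lookup failed: %s\n", TF_Message(status));
}
TF_DeleteStatus(status);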
@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"

extern "C" {

TF_SignatureDefFunctionMetadata* TF_SignatureDefFunctionGetMetadata(
    TF_SignatureDefFunction* func) {
  return tensorflow::wrap(const_cast<tensorflow::SignatureDefFunctionMetadata*>(
      &tensorflow::unwrap(func)->GetFunctionMetadata()));
}

TFE_Op* TF_SignatureDefFunctionMakeCallOp(TF_SignatureDefFunction* func,
                                          TFE_TensorHandle** inputs,
                                          int num_inputs, TF_Status* status) {
  tensorflow::ImmediateOpPtr call_op;
  absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
      reinterpret_cast<tensorflow::AbstractTensorHandle**>(
          tensorflow::unwrap(inputs)),
      static_cast<size_t>(num_inputs));
  status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
  if (!status->status.ok()) {
    return nullptr;
  }
  return tensorflow::wrap(call_op.release());
}

}  // end extern "C"
@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"

#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"

// TODO(bmzhao): Add getter functions here as necessary.
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"

typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunctionMetadata,
                            TF_SignatureDefFunctionMetadata)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"

typedef struct TF_SignatureDefFunction TF_SignatureDefFunction;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunction,
                            TF_SignatureDefFunction)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_
@ -24,6 +24,8 @@ exports_files(
        "concrete_function_list.h",
        "function_metadata.h",
        "saved_model_api.h",
        "signature_def_function.h",
        "signature_def_function_metadata.h",
    ],
    visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -39,6 +41,8 @@ cc_library(
        ":concrete_function_list",
        ":function_metadata",
        ":saved_model_api",
        ":signature_def_function",
        ":signature_def_function_metadata",
    ],
)

@ -61,3 +65,13 @@ alias(
    name = "saved_model_api",
    actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)

alias(
    name = "signature_def_function",
    actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function",
)

alias(
    name = "signature_def_function_metadata",
    actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
)
@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
// IWYU pragma: end_exports

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
@ -40,6 +40,13 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
// The caller is responsible for deleting the returned TFE_Op. If op
// construction fails, `status` will be non-OK and the returned pointer will be
// null.
// TODO(bmzhao): Remove this function in a subsequent change; Design + implement
// a Function Execution interface for ConcreteFunction that accepts a tagged
// union of types (tensorflow::Value). This effectively requires moving much of
// the implementation of function.py/def_function.py to C++, and exposing a
// high-level API here. A strawman for what this interface could look like:
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
//                              inputs, int num_inputs, TF_Status* status);
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
    TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
    TF_Status* status);
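
The comment above stops at op construction, so a short hedged sketch of the full call sequence may help; `func`, `inputs`, `num_inputs`, and the single expected output are assumptions, while TFE_Execute is the standard eager entry point for running an op:

// Sketch only: obtain the call op and run it.
TFE_Op* call_op =
    TF_ConcreteFunctionGetCallOp(func, inputs, num_inputs, status);
if (TF_GetCode(status) == TF_OK) {
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  // Per the contract above, do not add further inputs to call_op.
  TFE_Execute(call_op, retvals, &num_retvals, status);
  TFE_DeleteOp(call_op);
}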
@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/tf_status.h"

#ifdef __cplusplus
@ -91,10 +92,13 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
// status - Set to OK on success and an appropriate error on failure.
// Returns:
//  If status is not OK, returns nullptr. Otherwise, returns a
//  TF_ConcreteFunction instance. Once `model` is deleted, all
//  `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
    TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
//  TF_SignatureDefFunction instance. Once `model` is deleted, all
//  `TF_SignatureDefFunctions` retrieved from it are invalid, and have been
//  deleted.
TF_CAPI_EXPORT extern TF_SignatureDefFunction*
TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
                                     const char* signature_def_key,
                                     TF_Status* status);

// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// An opaque type that corresponds to a SignatureDefFunction loaded from a
// SavedModel.
typedef struct TF_SignatureDefFunction TF_SignatureDefFunction;

// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_SignatureDefFunctionMetadata*
TF_SignatureDefFunctionGetMetadata(TF_SignatureDefFunction* func);

// Returns a TFE_Op suitable for executing this function. Caller must provide
// all function inputs in `inputs`, and must not add any additional inputs on
// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList).
// The caller is responsible for deleting the returned TFE_Op. If op
// construction fails, `status` will be non-OK and the returned pointer will be
// null.
TF_CAPI_EXPORT extern TFE_Op* TF_SignatureDefFunctionMakeCallOp(
    TF_SignatureDefFunction* func, TFE_TensorHandle** inputs, int num_inputs,
    TF_Status* status);

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
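
Taken together, the two entry points above suggest the following usage shape. This is a hedged sketch: `sig_fn`, `inputs`, `num_inputs`, and the single expected output are assumed rather than guaranteed by this header, and metadata getters are still a TODO at this point in the series:

// Sketch only.
TF_SignatureDefFunctionMetadata* meta =
    TF_SignatureDefFunctionGetMetadata(sig_fn);  // lifetime bound to sig_fn
TFE_Op* op =
    TF_SignatureDefFunctionMakeCallOp(sig_fn, inputs, num_inputs, status);
if (TF_GetCode(status) == TF_OK) {
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);
  TFE_DeleteOp(op);
}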
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// An opaque type that corresponds to the metadata of a SignatureDefFunction
// loaded from a SavedModel.
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
60
tensorflow/c/experimental/stream_executor/BUILD
Normal file
@ -0,0 +1,60 @@
# Description:
# StreamExecutor C API.

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "stream_executor",
    srcs = ["stream_executor.cc"],
    hdrs = ["stream_executor.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":stream_executor_internal",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor:executor_cache",
        "//tensorflow/stream_executor:multi_platform_manager",
        "//tensorflow/stream_executor:platform",
        "//tensorflow/stream_executor:stream_executor_internal",
        "//tensorflow/stream_executor:stream_executor_pimpl",
        "//tensorflow/stream_executor:timer",
    ],
)

cc_library(
    name = "stream_executor_internal",
    hdrs = [
        "stream_executor.h",
        "stream_executor_internal.h",
    ],
    deps = [
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status",
        "//tensorflow/stream_executor:executor_cache",
        "//tensorflow/stream_executor/lib",
    ],
)

tf_cc_test(
    name = "stream_executor_test",
    srcs = ["stream_executor_test.cc"],
    deps = [
        ":stream_executor",
        ":stream_executor_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
        "//tensorflow/stream_executor:multi_platform_manager",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor:stream_executor_pimpl",
    ],
)
809
tensorflow/c/experimental/stream_executor/stream_executor.cc
Normal file
@ -0,0 +1,809 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "SP_Something".
//
// This file also contains stream_executor::Platform registration for pluggable
// device.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

#include <string>

#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"

using tensorflow::StatusFromTF_Status;

namespace stream_executor {
namespace {

#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
  do {                                                                 \
    if (STRUCT_OBJ.struct_size == 0) {                                 \
      return port::FailedPreconditionError(                            \
          "struct_size field in " #STRUCT_NAME                         \
          " must be set to " #SIZE_VALUE_NAME ".");                    \
    }                                                                  \
  } while (0)

#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)           \
  do {                                                           \
    if (STRUCT_OBJ.NAME == 0) {                                  \
      return port::FailedPreconditionError(                      \
          "'" #NAME "' field in " #STRUCT_NAME " must be set."); \
    }                                                            \
  } while (0)

port::Status ValidateSPPlatform(const SP_Platform& platform) {
  VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_Platform, platform, name);
  VALIDATE_MEMBER(SP_Platform, platform, type);
  VALIDATE_MEMBER(SP_Platform, platform, visible_device_count);
  VALIDATE_MEMBER(SP_Platform, platform, create_device);
  VALIDATE_MEMBER(SP_Platform, platform, destroy_device);
  VALIDATE_MEMBER(SP_Platform, platform, create_stream_executor);
  VALIDATE_MEMBER(SP_Platform, platform, destroy_stream_executor);
  VALIDATE_MEMBER(SP_Platform, platform, create_timer_fns);
  VALIDATE_MEMBER(SP_Platform, platform, destroy_timer_fns);
  return port::Status::OK();
}

port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
  VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds);
  return port::Status::OK();
}

port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
  VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
  VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
                       SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPDevice(const SP_Device& device) {
  VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return port::Status::OK();
}

port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se) {
  VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
  VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
  VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
  VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status);
  VALIDATE_MEMBER(SP_StreamExecutor, se, record_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer);
  VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh);
  VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod);
  VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh);
  VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod);
  VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event);
  VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity);
  VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback);
  return port::Status::OK();
}

port::Status ValidateSEPlatformRegistrationParams(
    const SE_PlatformRegistrationParams& params) {
  VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
                       SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
  VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
  return port::Status::OK();
}

#undef VALIDATE_MEMBER

struct TFStatusDeleter {
  void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;

class CStream : public internal::StreamInterface {
 public:
  CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        stream_handle_(nullptr) {}
  ~CStream() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    return s;
  }

  void Destroy() {
    if (stream_handle_ != nullptr) {
      stream_executor_->destroy_stream(device_, stream_handle_);
      stream_handle_ = nullptr;
    }
  }

  SP_Stream Handle() { return stream_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Stream stream_handle_;
};

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  switch (s) {
    case SE_EVENT_ERROR:
      return Event::Status::kError;
    case SE_EVENT_PENDING:
      return Event::Status::kPending;
    case SE_EVENT_COMPLETE:
      return Event::Status::kComplete;
    default:
      return Event::Status::kUnknown;
  }
}

class CEvent : public internal::EventInterface {
 public:
  CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        event_handle_(nullptr) {}
  ~CEvent() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_event(device_, &event_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status Record(SP_Stream stream_handle) {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->record_event(device_, stream_handle, event_handle_,
                                   c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (event_handle_ != nullptr) {
      stream_executor_->destroy_event(device_, event_handle_);
      event_handle_ = nullptr;
    }
  }

  SP_Event Handle() { return event_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Event event_handle_;
};

class CTimer : public internal::TimerInterface {
 public:
  CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
         SP_TimerFns* timer_fns)
      : device_(device),
        stream_executor_(stream_executor),
        timer_handle_(nullptr),
        timer_fns_(timer_fns) {}
  ~CTimer() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (timer_handle_ != nullptr) {
      stream_executor_->destroy_timer(device_, timer_handle_);
      timer_handle_ = nullptr;
    }
  }

  SP_Timer Handle() { return timer_handle_; }

  uint64 Microseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_) / 1000;
  }

  uint64 Nanoseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_);
  }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Timer timer_handle_;
  SP_TimerFns* timer_fns_;
};

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
  SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
  // `opaque` field inside SP_DeviceMemoryBase is not const.
  // Therefore, we need to cast away the constness before setting it.
  device_memory_base.opaque = const_cast<void*>(mem->opaque());
  device_memory_base.size = mem->size();
  device_memory_base.payload = mem->payload();
  // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
  return device_memory_base;
}

DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
  DeviceMemoryBase base(mem.opaque, mem.size);
  base.SetPayload(mem.payload);
  // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
  return base;
}

// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
  std::function<port::Status()> callback;
};

// This wrapper allows calling `HostCallbackContext::callback` across C API.
// This function matches `SE_StatusCallbackFn` signature and will be passed as
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
  HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
  port::Status s = host_ctx->callback();
  Set_TF_Status_from_Status(status, s);
  delete host_ctx;
}
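
HostCallbackTrampoline only makes sense next to the plugin side of the contract, which this file does not show. A hedged sketch of a minimal plugin-side `host_callback` under the assumptions visible here (the parameter list mirrors the call site in CStreamExecutor::HostCallback below; a real plugin would enqueue the callback behind its stream's outstanding work instead of running it inline):

// Sketch only: the plugin must invoke callback_fn(ctx, status) exactly once;
// the trampoline above then runs the wrapped std::function and frees ctx.
TF_Bool my_host_callback(SP_Device* device, SP_Stream stream,
                         SE_StatusCallbackFn callback_fn, void* ctx) {
  TF_Status* status = TF_NewStatus();
  callback_fn(ctx, status);  // synchronous for illustration only
  TF_Bool ok = TF_GetCode(status) == TF_OK;
  TF_DeleteStatus(status);
  return ok;
}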
class CStreamExecutor : public internal::StreamExecutorInterface {
|
||||
public:
|
||||
explicit CStreamExecutor(SP_Device device,
|
||||
void (*destroy_device)(SP_Device* const device),
|
||||
SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns, const std::string& name,
|
||||
int visible_device_count)
|
||||
: device_(std::move(device)),
|
||||
destroy_device_(destroy_device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_fns_(timer_fns),
|
||||
platform_name_(name),
|
||||
visible_device_count_(visible_device_count) {}
|
||||
|
||||
~CStreamExecutor() override { destroy_device_(&device_); }
|
||||
|
||||
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override {
|
||||
SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||
stream_executor_->allocate(&device_, size, memory_space, &mem);
|
||||
port::Status status = ValidateSPDeviceMemoryBase(mem);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status.error_message();
|
||||
}
|
||||
return DeviceMemoryBaseFromC(mem);
|
||||
}
|
||||
DeviceMemoryBase Allocate(uint64 size) {
|
||||
return Allocate(size, /*memory_space=*/0);
|
||||
}
|
||||
void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
|
||||
uint64 size) override {
|
||||
LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
|
||||
}
|
||||
|
||||
void Deallocate(DeviceMemoryBase* mem) override {
|
||||
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
|
||||
stream_executor_->deallocate(&device_, &device_memory_base);
|
||||
}
|
||||
|
||||
void* HostMemoryAllocate(uint64 size) override {
|
||||
return stream_executor_->host_memory_allocate(&device_, size);
|
||||
}
|
||||
|
||||
void HostMemoryDeallocate(void* mem) override {
|
||||
stream_executor_->host_memory_deallocate(&device_, mem);
|
||||
}
|
||||
|
||||
bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
|
||||
bool HostMemoryUnregister(void* mem) override { return false; }
|
||||
|
||||
absl::optional<AllocatorStats> GetAllocatorStats() override {
|
||||
SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
|
||||
TF_Bool has_stats =
|
||||
stream_executor_->get_allocator_stats(&device_, &c_stats);
|
||||
if (!has_stats) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
port::Status status = ValidateSPAllocatorStats(c_stats);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status.error_message();
|
||||
return absl::nullopt;
|
||||
}
|
||||
// TODO(annarev): validate SP_AllocatorStats.
|
||||
::stream_executor::AllocatorStats stats;
|
||||
stats.num_allocs = c_stats.num_allocs;
|
||||
stats.bytes_in_use = c_stats.bytes_in_use;
|
||||
stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
|
||||
stats.largest_alloc_size = c_stats.largest_alloc_size;
|
||||
if (c_stats.has_bytes_limit) {
|
||||
stats.bytes_limit = c_stats.bytes_limit;
|
||||
}
|
||||
stats.bytes_reserved = c_stats.bytes_reserved;
|
||||
stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
|
||||
if (c_stats.has_bytes_reservable_limit) {
|
||||
stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
|
||||
}
|
||||
stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
|
||||
return stats;
|
||||
}
|
||||
bool SynchronizeAllActivity() override {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->synchronize_all_activity(&device_, c_status.get());
|
||||
if (TF_GetCode(c_status.get()) != TF_OK) {
|
||||
LOG(ERROR) << TF_Message(c_status.get());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
port::Status SynchronousMemZero(DeviceMemoryBase* location,
|
||||
uint64 size) override {
|
||||
// TODO(annarev): figure out if we should support memzero/memset
|
||||
// functionality by allocating on host and then copying to device.
|
||||
return port::UnimplementedError(
|
||||
"SynchronousMemZero is not supported by pluggable device.");
|
||||
}
|
||||
port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
|
||||
uint64 size) override {
|
||||
return port::UnimplementedError(
|
||||
"SynchronousMemSet is not supported by pluggable device.");
|
||||
}
|
||||
port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
|
||||
const void* host_src, uint64 size) override {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
|
||||
stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
|
||||
size, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
port::Status SynchronousMemcpy(void* host_dst,
|
||||
const DeviceMemoryBase& gpu_src,
|
||||
uint64 size) override {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
|
||||
stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
|
||||
size, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
|
||||
const DeviceMemoryBase& gpu_src,
|
||||
uint64 size) override {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
|
||||
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
|
||||
stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
|
||||
&device_mem_src, size, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
|
||||
uint64 size) override {
|
||||
return port::UnimplementedError(
|
||||
"MemZero is not supported by pluggable device.");
|
||||
}
|
||||
port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
|
||||
uint64 size) override {
|
||||
return port::UnimplementedError(
|
||||
"Memset is not supported by pluggable device.");
|
||||
}
|
||||
port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
|
||||
uint32 pattern, uint64 size) override {
|
||||
return port::UnimplementedError(
|
||||
"Memset32 is not supported by pluggable device.");
|
||||
}
|
||||
  bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
                                  host_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
                            const DeviceMemoryBase& gpu_src,
                            uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool HostCallback(Stream* stream,
                    std::function<port::Status()> callback) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    // Wrap the std::function in a heap-allocated context so it can cross the
    // C ABI; HostCallbackTrampoline invokes it and cleans it up.
    HostCallbackContext* ctx = new HostCallbackContext{callback};
    return stream_executor_->host_callback(&device_, stream_handle,
                                           &HostCallbackTrampoline, ctx);
  }
  port::Status AllocateEvent(Event* event) override {
    DCHECK(event != nullptr);
    return static_cast<CEvent*>(event->implementation())->Create();
  }
  port::Status DeallocateEvent(Event* event) override {
    static_cast<CEvent*>(event->implementation())->Destroy();
    return port::Status::OK();
  }
  port::Status RecordEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
  }
  port::Status WaitForEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
                                     c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  Event::Status PollForEventStatus(Event* event) override {
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    SE_EventStatus event_status =
        stream_executor_->get_event_status(&device_, event_handle);
    return SEEventStatusToEventStatus(event_status);
  }
  bool AllocateStream(Stream* stream) override {
    DCHECK(stream != nullptr);
    port::Status status =
        static_cast<CStream*>(stream->implementation())->Create();
    // TODO(annarev): update AllocateStream to return status instead
    // (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateStream(Stream* stream) override {
    static_cast<CStream*>(stream->implementation())->Destroy();
  }
  bool CreateStreamDependency(Stream* dependent, Stream* other) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream dependent_handle =
        static_cast<CStream*>(dependent->implementation())->Handle();
    SP_Stream other_handle =
        static_cast<CStream*>(other->implementation())->Handle();
    stream_executor_->create_stream_dependency(&device_, dependent_handle,
                                               other_handle, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool AllocateTimer(Timer* timer) override {
    port::Status status =
        static_cast<CTimer*>(timer->implementation())->Create();
    // TODO(annarev): change return value of AllocateTimer
    // to status (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateTimer(Timer* timer) override {
    static_cast<CTimer*>(timer->implementation())->Destroy();
  }
  bool StartTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->start_timer(&device_, stream_handle, timer_handle,
                                  c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool StopTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
                                 c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  port::Status BlockHostForEvent(Stream* stream, Event* event) {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status BlockHostUntilDone(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Event event_handle;
    // The C API has no dedicated "block until stream is done" callback, so
    // synthesize one: create a temporary event, record it at the end of the
    // stream, block the host on it, then destroy it.
    stream_executor_->create_event(&device_, &event_handle, c_status.get());
    TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    stream_executor_->record_event(&device_, stream_handle, event_handle,
                                   c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    if (!s.ok()) {
      stream_executor_->destroy_event(&device_, event_handle);
      return s;
    }
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    stream_executor_->destroy_event(&device_, event_handle);
    return StatusFromTF_Status(c_status.get());
  }

  port::Status GetStatus(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    stream_executor_->get_stream_status(&device_, stream_handle,
                                        c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  int PlatformDeviceCount() override { return visible_device_count_; }
  port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    return port::UnimplementedError(
        "EnablePeerAccessTo is not supported by pluggable device.");
  }
  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    return false;
  }

  bool DeviceMemoryUsage(int64* free, int64* total) const override {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    return stream_executor_->device_memory_usage(
        &device_, reinterpret_cast<int64_t*>(free),
        reinterpret_cast<int64_t*>(total));
  }

  // Creates a new DeviceDescription object.
  // Ownership is transferred to the caller.
  port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
      const override {
    // TODO(annarev): Figure out if we need to support more description fields.
    internal::DeviceDescriptionBuilder builder;
    builder.set_name(platform_name_);
    return builder.Build();
  }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
      override {
    return std::unique_ptr<internal::EventInterface>(
        new CEvent(&device_, stream_executor_));
  }
  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
      override {
    LOG(FATAL)
        << "CreateKernelImplementation is not supported by pluggable device.";
  }
  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
      override {
    return std::unique_ptr<internal::StreamInterface>(
        new CStream(&device_, stream_executor_));
  }
  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
    return std::unique_ptr<internal::TimerInterface>(
        new CTimer(&device_, stream_executor_, timer_fns_));
  }

 private:
  SP_Device device_;
  void (*destroy_device_)(SP_Device* const device);
  SP_StreamExecutor* stream_executor_;
  SP_TimerFns* timer_fns_;
  std::string platform_name_;
  int visible_device_count_;
};
}  // namespace

CPlatform::CPlatform(SP_Platform platform,
                     void (*destroy_platform)(SP_Platform*),
                     SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
    : platform_(std::move(platform)),
      destroy_platform_(destroy_platform),
      stream_executor_(std::move(stream_executor)),
      timer_fns_(std::move(timer_fns)),
      // Read the name from the already-initialized member rather than the
      // moved-from parameter (platform_ is declared first, so it is
      // initialized before name_).
      name_(platform_.name) {}

CPlatform::~CPlatform() {
  executor_cache_.DestroyAllExecutors();
  platform_.destroy_stream_executor(&stream_executor_);
  platform_.destroy_timer_fns(&timer_fns_);
  destroy_platform_(&platform_);
}

port::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
  // TODO(annarev): see if we can get StreamExecutor instance
  // and call GetDeviceDescription. executor_cache_.Get would need
  // to be made const for it to work.
  internal::DeviceDescriptionBuilder builder;
  builder.set_name(name_);
  return builder.Build();
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
  stream_executor::StreamExecutorConfig config;
  config.ordinal = ordinal;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
    int ordinal, const PluginConfig& plugin_config) {
  StreamExecutorConfig config;
  config.ordinal = ordinal;
  config.plugin_config = plugin_config;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
    const StreamExecutorConfig& config) {
  return executor_cache_.GetOrCreate(
      config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
    const StreamExecutorConfig& config) {
  // Fill device creation params.
  SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
  SP_Device device{SP_DEVICE_STRUCT_SIZE};
  device_params.device = &device;
  device_params.ext = nullptr;
  device_params.ordinal = config.ordinal;
  OwnedTFStatus c_status(TF_NewStatus());

  // Create device.
  platform_.create_device(&device_params, c_status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPDevice(device));

  auto executor = absl::make_unique<CStreamExecutor>(
      std::move(device), platform_.destroy_device, &stream_executor_,
      &timer_fns_, name_, platform_.visible_device_count);
  auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
                                                  config.ordinal);
  return result;
}

port::Status RegisterDevicePlugin(const std::string& dso_path) {
  // Step 1: Load plugin.
  tensorflow::Env* env = tensorflow::Env::Default();
  void* dso_handle;
  TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));

  // Step 2: Load symbol for `SE_InitPlugin`.
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));

  // Step 3: Call `SE_InitPlugin`.
  auto init_fn = reinterpret_cast<SEPluginInitFn>(dso_symbol);
  return RegisterDevicePlugin(init_fn);
}

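// A hypothetical usage sketch (not part of this change): registering a vendor
// plugin at process startup. The DSO path below is made up for illustration.
//
//   port::Status s =
//       stream_executor::RegisterDevicePlugin("/opt/vendor/libmy_device.so");
//   if (!s.ok()) {
//     LOG(WARNING) << "Pluggable device registration failed: " << s;
//   }
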
port::Status RegisterDevicePlugin(SEPluginInitFn init_fn) {
  SE_PlatformRegistrationParams params{
      SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
  SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
  params.major_version = SE_MAJOR;
  params.minor_version = SE_MINOR;
  params.revision_version = SE_REVISION;
  params.platform = &platform;

  OwnedTFStatus c_status(TF_NewStatus());
  init_fn(&params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
  TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));

  // Fill stream executor creation params.
  SE_CreateStreamExecutorParams se_params{
      SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
  SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
  se_params.stream_executor = &se;

  // Create StreamExecutor.
  platform.create_stream_executor(&se_params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se));

  SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
  platform.create_timer_fns(&timer_fns, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));

  // Register new platform.
  std::string platform_name = std::string(platform.name);
  std::unique_ptr<stream_executor::CPlatform> cplatform(
      new stream_executor::CPlatform(std::move(platform),
                                     params.destroy_platform, std::move(se),
                                     std::move(timer_fns)));
  SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
      std::move(cplatform)));

  // TODO(annarev): Add pluggable device registration here.
  return port::Status::OK();
}
}  // namespace stream_executor
tensorflow/c/experimental/stream_executor/stream_executor.h (new file, 395 lines)
@ -0,0 +1,395 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_status.h"

// --------------------------------------------------------------------------
// C API for StreamExecutor. The API is under active development and eventually
// should allow registering a pluggable device with TensorFlow.
//
// Conventions:
//   * Struct prefix indicates whether struct fields should be filled by the
//     plugin or core implementation:
//     * SE_ : set/filled by core unless explicitly marked otherwise.
//     * SP_ : set/filled by plugin unless explicitly marked otherwise.
//   * We use `struct_size` for version checking. It is exempt from the `SE/SP`
//     rule above and should be set both by core and the plugin.
//     * For example, the `create_device` function receives `SP_Device*` as
//       input with `struct_size` populated by core. The plugin is responsible
//       for setting `struct_size` as well, along with all other fields.
//     * Refer to the "TensorFlow Versioning Strategy" section at
//       https://github.com/tensorflow/community/pull/257/files.
//     * Note that the API is still under active development and doesn't have
//       versioning guarantees yet.
//   * `void* ext` is a free-form field that can be populated by
//     a plugin in `SP_*` structs or potential future extension points in `SE_`
//     structs.
//
// Example usage:
//
//   /* Sample TensorFlow code below, exact implementation might differ. */
//   // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule
//   // above and should be set both by core and the plugin.
//   SP_Device device { SP_DEVICE_STRUCT_SIZE };
//   SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE };
//   params.device = &device;
//
//   /* Plugin code below */
//   constexpr char DEVICE_NAME[] = "MyDevice";
//   constexpr char DEVICE_TYPE[] = "GPU";
//
//   void create_device(SE_CreateDeviceParams* params, TF_Status* status) {
//     // Custom actions based on TensorFlow's view of SP_Device.
//     OnTFDeviceView(params->device->struct_size);
//     *params->device = { SP_DEVICE_STRUCT_SIZE };
//     params->device->device_handle = get_my_device_handle(params->ordinal);
//     params->device->ordinal = params->ordinal;
//     ...
//   }
//
//   void destroy_device(SP_Device* device) {
//     delete_my_device_handle(device->device_handle);
//   }
//
//   void SE_InitPlugin(
//       SE_PlatformRegistrationParams* params,
//       TF_Status* status) {
//     *params->platform = { SP_PLATFORM_STRUCT_SIZE };
//     // Values such as `name` and `type` must outlive the SE_InitPlugin call.
//     params->platform->name = DEVICE_NAME;
//     params->platform->type = DEVICE_TYPE;
//     params->platform->visible_device_count = 2;
//     params->platform->create_device = create_device;
//     params->platform->destroy_device = destroy_device;
//     ...
//   }

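// A minimal sketch (not from this header) of how `struct_size` can gate access
// to fields added in newer versions. It assumes TF_OFFSET_OF_END(S, f) expands
// to offsetof(S, f) plus the size of `f`, which is how the *_STRUCT_SIZE
// macros below are built:
//
//   if (platform->struct_size >=
//       TF_OFFSET_OF_END(SP_Platform, create_timer_fns)) {
//     // The plugin populated `create_timer_fns`; an older plugin would have
//     // reported a smaller struct_size, so the field must not be read.
//   }
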
#define SE_MAJOR 0
#define SE_MINOR 0
#define SE_REVISION 1

#ifdef __cplusplus
extern "C" {
#endif

typedef struct SP_Stream_st* SP_Stream;
typedef struct SP_Event_st* SP_Event;
typedef struct SP_Timer_st* SP_Timer;
// Takes `callback_arg` passed to `host_callback` as the first argument.
typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const);

typedef struct SP_TimerFns {
  size_t struct_size;
  void* ext;  // reserved for future use
  uint64_t (*nanoseconds)(SP_Timer timer);
} SP_TimerFns;

#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds)

typedef struct SP_AllocatorStats {
  size_t struct_size;
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;

  int8_t has_bytes_limit;
  int64_t bytes_limit;

  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;

  int8_t has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;
} SP_AllocatorStats;

#define SP_ALLOCATORSTATS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes)

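// A plugin-side sketch of filling these stats; `my_allocator_bytes_in_use` is
// a hypothetical vendor hook, and the has_* flags show how optional fields are
// marked absent:
//
//   TF_Bool get_allocator_stats(const SP_Device* device,
//                               SP_AllocatorStats* stats) {
//     stats->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE;
//     stats->bytes_in_use = my_allocator_bytes_in_use();  // hypothetical
//     stats->has_bytes_limit = 0;  // no limit, so bytes_limit stays unset
//     stats->has_bytes_reservable_limit = 0;
//     return 1;  // stats were filled
//   }
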
// Potential states for an SP_Event. If `poll_for_status` returns anything
// aside from kPending or kComplete, an error has occurred; kUnknown is a bad
// state.
typedef enum SE_EventStatus {
  SE_EVENT_UNKNOWN,
  SE_EVENT_ERROR,
  SE_EVENT_PENDING,
  SE_EVENT_COMPLETE,
} SE_EventStatus;

// Memory allocation information.
// This matches DeviceMemoryBase defined here:
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase {
  size_t struct_size;
  void* ext;  // free-form data set by plugin
  // Platform-dependent value representing allocated memory.
  void* opaque;
  uint64_t size;     // Size in bytes of this allocation.
  uint64_t payload;  // Value for plugin's use
} SP_DeviceMemoryBase;

#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload)

typedef struct SP_Device {
  size_t struct_size;
  void* ext;        // free-form data set by plugin
  int32_t ordinal;  // device index

  // Device vendor can store handle to their device representation
  // here.
  void* device_handle;
} SP_Device;

#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle)

typedef struct SE_CreateDeviceParams {
  size_t struct_size;
  void* ext;        // reserved for future use
  int32_t ordinal;  // device index

  SP_Device* device;  // Input/output, struct_size set by TF for plugin to
                      // read. Subsequently plugin fills the entire struct.
} SE_CreateDeviceParams;

#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SE_CreateDeviceParams, device)

typedef struct SP_StreamExecutor {
  size_t struct_size;
  void* ext;  // reserved for future use

  /*** ALLOCATION CALLBACKS ***/
  // Synchronously allocates `size` bytes on the underlying platform and fills
  // `mem` with an `SP_DeviceMemoryBase` representing that allocation. In the
  // case of failure, `mem->opaque` should be set to nullptr.
  // `memory_space` is reserved for a potential future usage and should be set
  // to 0.
  void (*allocate)(const SP_Device* device, uint64_t size,
                   int64_t memory_space, SP_DeviceMemoryBase* mem);

  // Deallocates the device memory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner is required for use in asynchronous memcpy
  // operations, such as `memcpy_dtoh`.
  void* (*host_memory_allocate)(const SP_Device* device, uint64_t size);

  // Deallocates a region of host memory allocated by `host_memory_allocate`.
  void (*host_memory_deallocate)(const SP_Device* device, void* mem);

  // Fills SP_AllocatorStats with allocator statistics, if they are available.
  // If they are not available, returns false.
  TF_Bool (*get_allocator_stats)(const SP_Device* device,
                                 SP_AllocatorStats* stats);
  // Fills the underlying device memory usage information, if it is
  // available. If it is not available (false is returned), free/total need not
  // be initialized.
  TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free,
                                 int64_t* total);

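  // A minimal malloc-backed sketch of the allocation pair; the test in this
  // change uses the same approach, while a real plugin would hand out device
  // memory instead of host memory:
  //
  //   void allocate(const SP_Device* device, uint64_t size,
  //                 int64_t memory_space, SP_DeviceMemoryBase* mem) {
  //     mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE;
  //     mem->opaque = malloc(size);  // device allocation in a real plugin
  //     mem->size = size;
  //   }
  //   void deallocate(const SP_Device* device, SP_DeviceMemoryBase* mem) {
  //     free(mem->opaque);
  //     mem->opaque = NULL;
  //     mem->size = 0;
  //   }
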
  /*** STREAM CALLBACKS ***/
  // Creates SP_Stream. This call should also allocate stream
  // resources on the underlying platform and initialize its
  // internals.
  void (*create_stream)(const SP_Device* device, SP_Stream* stream,
                        TF_Status* status);

  // Destroys SP_Stream and deallocates any underlying resources.
  void (*destroy_stream)(const SP_Device* device, SP_Stream stream);

  // Causes `dependent` to not begin execution until `other` has finished its
  // last-enqueued work.
  void (*create_stream_dependency)(const SP_Device* device,
                                   SP_Stream dependent, SP_Stream other,
                                   TF_Status* status);

  // Without blocking the device, retrieves the current stream status.
  void (*get_stream_status)(const SP_Device* device, SP_Stream stream,
                            TF_Status* status);

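  // For reference, core drives these callbacks from
  // CStreamExecutor::CreateStreamDependency (earlier in this change); on
  // TF_OK, work later enqueued on `dependent` does not start until everything
  // already enqueued on `other` has finished:
  //
  //   stream_executor->create_stream_dependency(device, dependent_handle,
  //                                             other_handle, status);
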
  /*** EVENT CALLBACKS ***/
  // Creates SP_Event. Performs platform-specific allocation and initialization
  // of an event.
  void (*create_event)(const SP_Device* device, SP_Event* event,
                       TF_Status* status);

  // Destroys SP_Event and performs any platform-specific deallocation and
  // cleanup of an event.
  void (*destroy_event)(const SP_Device* device, SP_Event event);

  // Requests the current status of the event from the underlying platform.
  SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event);
  // Inserts the specified event at the end of the specified stream.
  void (*record_event)(const SP_Device* device, SP_Stream stream,
                       SP_Event event, TF_Status* status);

  // Waits for the specified event at the end of the specified stream.
  void (*wait_for_event)(const SP_Device* const device, SP_Stream stream,
                         SP_Event event, TF_Status* const status);

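  // These callbacks compose into higher-level synchronization. For example,
  // CStreamExecutor::BlockHostUntilDone (earlier in this change) implements
  // "wait until the stream is done" as:
  //
  //   create_event(device, &event, status);         // make a fresh event
  //   record_event(device, stream, event, status);  // enqueue it on stream
  //   block_host_for_event(device, event, status);  // host waits for it
  //   destroy_event(device, event);                 // clean up
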
  /*** TIMER CALLBACKS ***/
  // Creates SP_Timer. Allocates timer resources on the underlying platform
  // and initializes its internals, setting the `timer` output variable. Sets
  // values in the `timer_fns` struct.
  void (*create_timer)(const SP_Device* device, SP_Timer* timer,
                       TF_Status* status);

  // Destroys timer and deallocates timer resources on the underlying platform.
  void (*destroy_timer)(const SP_Device* device, SP_Timer timer);

  // Records a start event for an interval timer.
  void (*start_timer)(const SP_Device* device, SP_Stream stream,
                      SP_Timer timer, TF_Status* status);

  // Records a stop event for an interval timer.
  void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
                     TF_Status* status);

  /*** MEMCPY CALLBACKS ***/
  // Enqueues a memcpy operation onto stream, with a host destination location
  // `host_dst` and a device memory source, with target size `size`.
  void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream,
                      void* host_dst, const SP_DeviceMemoryBase* device_src,
                      uint64_t size, TF_Status* status);

  // Enqueues a memcpy operation onto stream, with a device destination
  // location and a host memory source, with target size `size`.
  void (*memcpy_htod)(const SP_Device* device, SP_Stream stream,
                      SP_DeviceMemoryBase* device_dst, const void* host_src,
                      uint64_t size, TF_Status* status);

  // Enqueues a memcpy operation onto stream, with a device destination
  // location and a device memory source, with target size `size`.
  void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream,
                      SP_DeviceMemoryBase* device_dst,
                      const SP_DeviceMemoryBase* device_src, uint64_t size,
                      TF_Status* status);

  // Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst,
                           const SP_DeviceMemoryBase* device_src,
                           uint64_t size, TF_Status* status);

  // Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  void (*sync_memcpy_htod)(const SP_Device* device,
                           SP_DeviceMemoryBase* device_dst,
                           const void* host_src, uint64_t size,
                           TF_Status* status);

  // Blocks the caller while a data segment of the given size is copied from
  // the device source to the device destination.
  void (*sync_memcpy_dtod)(const SP_Device* device,
                           SP_DeviceMemoryBase* device_dst,
                           const SP_DeviceMemoryBase* device_src,
                           uint64_t size, TF_Status* status);

  // Causes the host code to synchronously wait for the event to complete.
  void (*block_host_for_event)(const SP_Device* device, SP_Event event,
                               TF_Status* status);

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status);

  // Enqueues on a stream a user-specified function to be run on the host.
  // `callback_arg` should be passed as the first argument to `callback_fn`.
  TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream,
                           SE_StatusCallbackFn callback_fn,
                           void* callback_arg);
} SP_StreamExecutor;

#define SP_STREAMEXECUTOR_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_StreamExecutor, host_callback)

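// A plugin-side sketch of `host_callback`; `defer_to_stream` stands in for
// whatever hypothetical mechanism the platform uses to run host work when the
// stream reaches this point:
//
//   TF_Bool host_callback(SP_Device* device, SP_Stream stream,
//                         SE_StatusCallbackFn callback_fn,
//                         void* callback_arg) {
//     // Arrange for `callback_fn(callback_arg, status)` to run once prior
//     // work on `stream` completes; report success of the enqueue itself.
//     return defer_to_stream(stream, callback_fn, callback_arg);  // hypothetical
//   }
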
typedef struct SE_CreateStreamExecutorParams {
  size_t struct_size;
  void* ext;  // reserved for future use

  SP_StreamExecutor* stream_executor;  // output, to be filled by plugin
} SE_CreateStreamExecutorParams;

#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor)

typedef struct SP_Platform {
  size_t struct_size;

  void* ext;  // free-form data set by plugin

  // Platform name. Must be null-terminated.
  const char* name;

  // Device type name, for example GPU. Must be null-terminated.
  const char* type;

  // Number of visible devices.
  size_t visible_device_count;

  // Callbacks for creating/destroying SP_Device.
  void (*create_device)(SE_CreateDeviceParams* params, TF_Status* status);

  // Clean up fields inside SP_Device that were allocated
  // by the plugin. `device` itself should not be deleted here.
  void (*destroy_device)(SP_Device* device);

  // Callbacks for creating/destroying SP_StreamExecutor.
  void (*create_stream_executor)(SE_CreateStreamExecutorParams* params,
                                 TF_Status* status);
  // Clean up fields inside SP_StreamExecutor that were allocated
  // by the plugin. `stream_executor` itself should not be deleted here.
  void (*destroy_stream_executor)(SP_StreamExecutor* stream_executor);

  // Callbacks for creating/destroying SP_TimerFns.
  void (*create_timer_fns)(SP_TimerFns* timer_fns, TF_Status* status);

  void (*destroy_timer_fns)(SP_TimerFns* timer_fns);
} SP_Platform;

#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, destroy_timer_fns)

typedef struct SE_PlatformRegistrationParams {
  size_t struct_size;
  void* ext;  // reserved for future use

  // StreamExecutor C API version.
  int32_t major_version;
  int32_t minor_version;
  int32_t revision_version;

  SP_Platform* platform;  // output, set by plugin
  // Clean up fields inside SP_Platform that were allocated
  // by the plugin. `platform` itself should not be deleted here.
  void (*destroy_platform)(SP_Platform* platform);  // out, set by plugin
} SE_PlatformRegistrationParams;

#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform)

void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
tensorflow/c/experimental/stream_executor/stream_executor_internal.h (new file, 80 lines)
@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Classes and utilities that work with StreamExecutor C API for internal use.
// This includes functions used for device registration and interfaces needed
// for testing.
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"

namespace stream_executor {

// Plugin initialization function that a device plugin
// must define.
typedef void (*SEPluginInitFn)(SE_PlatformRegistrationParams* const,
                               TF_Status* const);

// Loads dso and registers StreamExecutor-based pluggable device.
port::Status RegisterDevicePlugin(const std::string& dso_path);

// Allows registering a plugin using a function (used for testing).
port::Status RegisterDevicePlugin(SEPluginInitFn init_fn);

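// A sketch of the function-based path; the registration tests in this change
// use the same shape (`my_destroy_platform` here is hypothetical):
//
//   auto plugin_init = [](SE_PlatformRegistrationParams* params,
//                         TF_Status* status) {
//     TF_SetStatus(status, TF_OK, "");
//     // Fill params->platform: name, type, visible_device_count, callbacks.
//     params->destroy_platform = my_destroy_platform;
//   };
//   TF_CHECK_OK(RegisterDevicePlugin(plugin_init));
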
class CPlatform : public Platform {
 public:
  explicit CPlatform(SP_Platform platform,
                     void (*destroy_platform)(SP_Platform*),
                     SP_StreamExecutor stream_executor, SP_TimerFns timer_fns);
  ~CPlatform() override;

  Id id() const override { return const_cast<int*>(&plugin_id_value_); }
  const std::string& Name() const override { return name_; }
  int VisibleDeviceCount() const override {
    return platform_.visible_device_count;
  }
  port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
      int ordinal) const override;
  port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
  port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
      int ordinal, const PluginConfig& plugin_config) override;
  port::StatusOr<StreamExecutor*> GetExecutor(
      const StreamExecutorConfig& config) override;
  port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
      const StreamExecutorConfig& config) override;

  // Trace listener is not supported.
  void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
    LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
  }
  void UnregisterTraceListener(TraceListener* listener) override {}

  void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }

 private:
  SP_Platform platform_;
  void (*destroy_platform_)(SP_Platform*);
  SP_StreamExecutor stream_executor_;
  SP_TimerFns timer_fns_;
  const std::string name_;
  int plugin_id_value_;
  stream_executor::ExecutorCache executor_cache_;
};

}  // namespace stream_executor
#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
tensorflow/c/experimental/stream_executor/stream_executor_test.cc (new file, 802 lines)
@ -0,0 +1,802 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"

struct SP_Stream_st {
  explicit SP_Stream_st(int id) : stream_id(id) {}
  int stream_id;
};

struct SP_Event_st {
  explicit SP_Event_st(int id) : event_id(id) {}
  int event_id;
};

struct SP_Timer_st {
  explicit SP_Timer_st(int id) : timer_id(id) {}
  int timer_id;
};

namespace stream_executor {
namespace {
constexpr int DEVICE_COUNT = 2;
constexpr char DEVICE_NAME[] = "MyDevice";
constexpr char DEVICE_TYPE[] = "GPU";

/*** Create SP_StreamExecutor (with empty functions) ***/
void allocate(const SP_Device* const device, uint64_t size,
              int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
void deallocate(const SP_Device* const device,
                SP_DeviceMemoryBase* const mem) {}
TF_Bool get_allocator_stats(const SP_Device* const device,
                            SP_AllocatorStats* const stats) {
  return true;
}
TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free,
                            int64_t* const total) {
  return true;
}
void create_stream(const SP_Device* const device, SP_Stream* stream,
                   TF_Status* const status) {
  *stream = nullptr;
}
void destroy_stream(const SP_Device* const device, SP_Stream stream) {}
void create_stream_dependency(const SP_Device* const device,
                              SP_Stream dependent, SP_Stream other,
                              TF_Status* const status) {}
void get_stream_status(const SP_Device* const device, SP_Stream stream,
                       TF_Status* const status) {}
void create_event(const SP_Device* const device, SP_Event* event,
                  TF_Status* const status) {
  *event = nullptr;
}
void destroy_event(const SP_Device* const device, SP_Event event) {}
SE_EventStatus get_event_status(const SP_Device* const device,
                                SP_Event event) {
  return SE_EVENT_UNKNOWN;
}
void record_event(const SP_Device* const device, SP_Stream stream,
                  SP_Event event, TF_Status* const status) {}
void wait_for_event(const SP_Device* const device, SP_Stream stream,
                    SP_Event event, TF_Status* const status) {}
void create_timer(const SP_Device* const device, SP_Timer* timer,
                  TF_Status* const status) {}
void destroy_timer(const SP_Device* const device, SP_Timer timer) {}
void start_timer(const SP_Device* const device, SP_Stream stream,
                 SP_Timer timer, TF_Status* const status) {}
void stop_timer(const SP_Device* const device, SP_Stream stream,
                SP_Timer timer, TF_Status* const status) {}
void memcpy_dtoh(const SP_Device* const device, SP_Stream stream,
                 void* host_dst, const SP_DeviceMemoryBase* const device_src,
                 uint64_t size, TF_Status* const status) {}
void memcpy_htod(const SP_Device* const device, SP_Stream stream,
                 SP_DeviceMemoryBase* const device_dst, const void* host_src,
                 uint64_t size, TF_Status* const status) {}
void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst,
                      const SP_DeviceMemoryBase* const device_src,
                      uint64_t size, TF_Status* const status) {}
void sync_memcpy_htod(const SP_Device* const device,
                      SP_DeviceMemoryBase* const device_dst,
                      const void* host_src, uint64_t size,
                      TF_Status* const status) {}
void block_host_for_event(const SP_Device* const device, SP_Event event,
                          TF_Status* const status) {}
void synchronize_all_activity(const SP_Device* const device,
                              TF_Status* const status) {}
TF_Bool host_callback(SP_Device* const device, SP_Stream stream,
                      SE_StatusCallbackFn const callback_fn,
                      void* const callback_arg) {
  return true;
}

void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
  se->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE;
  se->allocate = allocate;
  se->deallocate = deallocate;
  se->get_allocator_stats = get_allocator_stats;
  se->device_memory_usage = device_memory_usage;
  se->create_stream = create_stream;
  se->destroy_stream = destroy_stream;
  se->create_stream_dependency = create_stream_dependency;
  se->get_stream_status = get_stream_status;
  se->create_event = create_event;
  se->destroy_event = destroy_event;
  se->get_event_status = get_event_status;
  se->record_event = record_event;
  se->wait_for_event = wait_for_event;
  se->create_timer = create_timer;
  se->destroy_timer = destroy_timer;
  se->start_timer = start_timer;
  se->stop_timer = stop_timer;
  se->memcpy_dtoh = memcpy_dtoh;
  se->memcpy_htod = memcpy_htod;
  se->sync_memcpy_dtoh = sync_memcpy_dtoh;
  se->sync_memcpy_htod = sync_memcpy_htod;
  se->block_host_for_event = block_host_for_event;
  se->synchronize_all_activity = synchronize_all_activity;
  se->host_callback = host_callback;
}

/*** Create SP_TimerFns ***/
uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }

void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
  timer_fns->nanoseconds = nanoseconds;
}

/*** Create SP_Platform ***/
void create_timer_fns(SP_TimerFns* timer_fns, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  PopulateDefaultTimerFns(timer_fns);
}
void destroy_timer_fns(SP_TimerFns* timer_fns) {}

void create_stream_executor(SE_CreateStreamExecutorParams* params,
                            TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  PopulateDefaultStreamExecutor(params->stream_executor);
}
void destroy_stream_executor(SP_StreamExecutor* se) {}

void create_device(SE_CreateDeviceParams* params, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
}
void destroy_device(SP_Device* device) {}

void PopulateDefaultPlatform(SP_Platform* platform) {
  platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  platform->name = DEVICE_NAME;
  platform->type = DEVICE_TYPE;
  platform->visible_device_count = DEVICE_COUNT;
  platform->create_device = create_device;
  platform->destroy_device = destroy_device;
  platform->create_stream_executor = create_stream_executor;
  platform->destroy_stream_executor = destroy_stream_executor;
  platform->create_timer_fns = create_timer_fns;
  platform->destroy_timer_fns = destroy_timer_fns;
}

void destroy_platform(SP_Platform* const platform) {}

/*** Registration tests ***/
TEST(StreamExecutor, SuccessfulRegistration) {
  auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultPlatform(params->platform);
    params->destroy_platform = destroy_platform;
  };
  port::Status status = RegisterDevicePlugin(plugin_init);
  TF_ASSERT_OK(status);
  port::StatusOr<Platform*> maybe_platform =
      MultiPlatformManager::PlatformWithName("MyDevice");
  TF_ASSERT_OK(maybe_platform.status());
  Platform* platform = maybe_platform.ConsumeValueOrDie();
  ASSERT_EQ(platform->Name(), DEVICE_NAME);
  ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT);

  port::StatusOr<StreamExecutor*> maybe_executor =
      platform->ExecutorForDevice(0);
  TF_ASSERT_OK(maybe_executor.status());
  StreamExecutor* executor = maybe_executor.ConsumeValueOrDie();
  ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice");
}

TEST(StreamExecutor, NameNotSet) {
  auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultPlatform(params->platform);
    params->platform->name = nullptr;
    params->destroy_platform = destroy_platform;
  };

  port::Status status = RegisterDevicePlugin(plugin_init);
  ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
}

TEST(StreamExecutor, CreateDeviceNotSet) {
  auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultPlatform(params->platform);
    params->platform->create_device = nullptr;
    params->destroy_platform = destroy_platform;
  };

  port::Status status = RegisterDevicePlugin(plugin_init);
  ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(status.error_message(),
            "'create_device' field in SP_Platform must be set.");
}

/*** StreamExecutor behavior tests ***/
class StreamExecutorTest : public ::testing::Test {
 protected:
  StreamExecutorTest() {}
  void SetUp() override {
    PopulateDefaultPlatform(&platform_);
    PopulateDefaultStreamExecutor(&se_);
    PopulateDefaultTimerFns(&timer_fns_);
  }
  void TearDown() override {}

  // Each test overrides the SP_ callbacks it exercises (with capture-less
  // lambdas) before the first GetExecutor call constructs the platform.
  StreamExecutor* GetExecutor(int ordinal) {
    if (!cplatform_) {
      cplatform_ = absl::make_unique<CPlatform>(platform_, destroy_platform,
                                                se_, timer_fns_);
    }
    port::StatusOr<StreamExecutor*> maybe_executor =
        cplatform_->ExecutorForDevice(ordinal);
    TF_CHECK_OK(maybe_executor.status());
    return maybe_executor.ConsumeValueOrDie();
  }
  SP_Platform platform_;
  SP_StreamExecutor se_;
  SP_TimerFns timer_fns_;
  std::unique_ptr<CPlatform> cplatform_;
};

TEST_F(StreamExecutorTest, Allocate) {
  se_.allocate = [](const SP_Device* const device, uint64_t size,
                    int64_t memory_space, SP_DeviceMemoryBase* const mem) {
    mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE;
    mem->opaque = std::malloc(size);
    mem->size = size;
  };
  se_.deallocate = [](const SP_Device* const device,
                      SP_DeviceMemoryBase* const mem) {
    EXPECT_EQ(mem->size, 2 * sizeof(int));
    std::free(mem->opaque);
    mem->opaque = nullptr;
    mem->size = 0;
  };
  StreamExecutor* executor = GetExecutor(0);
  DeviceMemory<int> mem = executor->AllocateArray<int>(2);
  ASSERT_NE(mem.opaque(), nullptr);
  ASSERT_EQ(mem.size(), 2 * sizeof(int));
  executor->Deallocate(&mem);
  ASSERT_EQ(mem.opaque(), nullptr);
}

TEST_F(StreamExecutorTest, HostMemoryAllocate) {
  static bool allocate_called = false;
  static bool deallocate_called = false;
  se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) {
    allocate_called = true;
    return std::malloc(size);
  };
  se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) {
    std::free(mem);
    deallocate_called = true;
  };
  StreamExecutor* executor = GetExecutor(0);
  ASSERT_FALSE(allocate_called);
  void* mem = executor->HostMemoryAllocate(8);
  ASSERT_NE(mem, nullptr);
  ASSERT_TRUE(allocate_called);
  ASSERT_FALSE(deallocate_called);
  executor->HostMemoryDeallocate(mem);
  ASSERT_TRUE(deallocate_called);
}

TEST_F(StreamExecutorTest, GetAllocatorStats) {
  se_.get_allocator_stats = [](const SP_Device* const device,
                               SP_AllocatorStats* const stat) -> TF_Bool {
    stat->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE;
    stat->bytes_in_use = 123;
    return true;
  };

  StreamExecutor* executor = GetExecutor(0);
  absl::optional<AllocatorStats> optional_stats = executor->GetAllocatorStats();
  ASSERT_TRUE(optional_stats.has_value());
  AllocatorStats stats = optional_stats.value();
  ASSERT_EQ(stats.bytes_in_use, 123);
}

TEST_F(StreamExecutorTest, DeviceMemoryUsage) {
  se_.device_memory_usage = [](const SP_Device* const device,
                               int64_t* const free,
                               int64_t* const total) -> TF_Bool {
    *free = 45;
    *total = 7;
    return true;
  };

  StreamExecutor* executor = GetExecutor(0);
  int64 free = 0;
  int64 total = 0;
  executor->DeviceMemoryUsage(&free, &total);
  ASSERT_EQ(free, 45);
  ASSERT_EQ(total, 7);
}

TEST_F(StreamExecutorTest, CreateStream) {
  static bool stream_created = false;
  static bool stream_deleted = false;
  se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
                         TF_Status* const status) -> void {
    *stream = new SP_Stream_st(14);
    stream_created = true;
  };
  se_.destroy_stream = [](const SP_Device* const device,
                          SP_Stream stream) -> void {
    auto custom_stream = static_cast<SP_Stream_st*>(stream);
    ASSERT_EQ(custom_stream->stream_id, 14);
    delete custom_stream;
    stream_deleted = true;
  };

  StreamExecutor* executor = GetExecutor(0);
  ASSERT_FALSE(stream_created);
  Stream* stream = new Stream(executor);
  stream->Init();
  ASSERT_TRUE(stream->ok());
  ASSERT_TRUE(stream_created);
  ASSERT_FALSE(stream_deleted);
  delete stream;
  ASSERT_TRUE(stream_deleted);
}

TEST_F(StreamExecutorTest, CreateStreamDependency) {
  static bool create_stream_dependency_called = false;
  se_.create_stream_dependency = [](const SP_Device* const device,
                                    SP_Stream dependent, SP_Stream other,
                                    TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    create_stream_dependency_called = true;
  };

  StreamExecutor* executor = GetExecutor(0);
  Stream dependent(executor);
  dependent.Init();
  Stream other(executor);
  other.Init();
  ASSERT_FALSE(create_stream_dependency_called);
  dependent.ThenWaitFor(&other);
  ASSERT_TRUE(create_stream_dependency_called);
}

TEST_F(StreamExecutorTest, StreamStatus) {
  static bool status_ok = true;
  se_.get_stream_status = [](const SP_Device* const device, SP_Stream stream,
                             TF_Status* const status) -> void {
    if (status_ok) {
      TF_SetStatus(status, TF_OK, "");
    } else {
      TF_SetStatus(status, TF_INTERNAL, "Test error");
    }
  };

  StreamExecutor* executor = GetExecutor(0);
  Stream stream(executor);
  stream.Init();
  ASSERT_TRUE(stream.ok());
  TF_ASSERT_OK(stream.RefreshStatus());
  status_ok = false;
  auto updated_status = stream.RefreshStatus();
  ASSERT_FALSE(stream.ok());
  ASSERT_EQ(updated_status.error_message(), "Test error");
}

TEST_F(StreamExecutorTest, CreateEvent) {
  static bool event_created = false;
  static bool event_deleted = false;
  se_.create_event = [](const SP_Device* const device, SP_Event* event,
                        TF_Status* const status) -> void {
    *event = new SP_Event_st(123);
    event_created = true;
  };
  se_.destroy_event = [](const SP_Device* const device,
                         SP_Event event) -> void {
    auto custom_event = static_cast<SP_Event_st*>(event);
    ASSERT_EQ(custom_event->event_id, 123);
    delete custom_event;
    event_deleted = true;
  };

  StreamExecutor* executor = GetExecutor(0);
  ASSERT_FALSE(event_created);
  Event* event = new Event(executor);
  event->Init();
  ASSERT_TRUE(event_created);
  ASSERT_FALSE(event_deleted);
  delete event;
  ASSERT_TRUE(event_deleted);
}

TEST_F(StreamExecutorTest, PollForEventStatus) {
  static SE_EventStatus event_status = SE_EVENT_COMPLETE;
  se_.create_event = [](const SP_Device* const device, SP_Event* event,
                        TF_Status* const status) -> void {
    *event = new SP_Event_st(123);
  };
  se_.destroy_event = [](const SP_Device* const device,
                         SP_Event event) -> void { delete event; };
  se_.get_event_status = [](const SP_Device* const device,
                            SP_Event event) -> SE_EventStatus {
    EXPECT_EQ(event->event_id, 123);
    return event_status;
  };

  StreamExecutor* executor = GetExecutor(0);
  Event event(executor);
  event.Init();
  ASSERT_EQ(event.PollForStatus(), Event::Status::kComplete);
  event_status = SE_EVENT_ERROR;
  ASSERT_EQ(event.PollForStatus(), Event::Status::kError);
}

TEST_F(StreamExecutorTest, RecordAndWaitForEvent) {
  static bool record_called = false;
  static bool wait_called = false;
  se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
                         TF_Status* const status) -> void {
    *stream = new SP_Stream_st(1);
  };
  se_.destroy_stream = [](const SP_Device* const device,
                          SP_Stream stream) -> void { delete stream; };
  se_.create_event = [](const SP_Device* const device, SP_Event* event,
                        TF_Status* const status) -> void {
    *event = new SP_Event_st(2);
  };
  se_.destroy_event = [](const SP_Device* const device,
                         SP_Event event) -> void { delete event; };
  se_.record_event = [](const SP_Device* const device, SP_Stream stream,
                        SP_Event event, TF_Status* const status) {
    EXPECT_EQ(stream->stream_id, 1);
    EXPECT_EQ(event->event_id, 2);
    TF_SetStatus(status, TF_OK, "");
    record_called = true;
  };
  se_.wait_for_event = [](const SP_Device* const device, SP_Stream stream,
                          SP_Event event, TF_Status* const status) {
    EXPECT_EQ(stream->stream_id, 1);
    EXPECT_EQ(event->event_id, 2);
    TF_SetStatus(status, TF_OK, "");
    wait_called = true;
  };

  StreamExecutor* executor = GetExecutor(0);
  Event event(executor);
  event.Init();
  Stream stream(executor);
  stream.Init();
  ASSERT_FALSE(record_called);
  stream.ThenRecordEvent(&event);
  ASSERT_TRUE(record_called);
  ASSERT_FALSE(wait_called);
  stream.ThenWaitFor(&event);
  ASSERT_TRUE(wait_called);
}

TEST_F(StreamExecutorTest, CreateTimer) {
  static bool timer_created = false;
  static bool timer_deleted = false;
  se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
                        TF_Status* const status) -> void {
    *timer = new SP_Timer_st(25);
    timer_created = true;
  };
  se_.destroy_timer = [](const SP_Device* const device,
                         SP_Timer timer) -> void {
    auto custom_timer = static_cast<SP_Timer_st*>(timer);
    EXPECT_EQ(custom_timer->timer_id, 25);
    delete custom_timer;
    timer_deleted = true;
  };

  StreamExecutor* executor = GetExecutor(0);
  ASSERT_FALSE(timer_created);
  Stream stream(executor);
  stream.Init();
  Timer* timer = new Timer(executor);
  stream.InitTimer(timer);
  ASSERT_TRUE(stream.ok());
  ASSERT_TRUE(timer_created);
  ASSERT_FALSE(timer_deleted);
  delete timer;
  ASSERT_TRUE(timer_deleted);
}

TEST_F(StreamExecutorTest, StartTimer) {
  static bool start_called = false;
  static bool stop_called = false;
  static TF_Code start_timer_status = TF_OK;
  static TF_Code stop_timer_status = TF_OK;
  se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
                        TF_Status* const status) -> void {
    *timer = new SP_Timer_st(7);
  };
  se_.destroy_timer = [](const SP_Device* const device,
                         SP_Timer timer) -> void { delete timer; };
  se_.start_timer = [](const SP_Device* const device, SP_Stream stream,
                       SP_Timer timer, TF_Status* const status) {
    TF_SetStatus(status, start_timer_status, "");
    EXPECT_EQ(timer->timer_id, 7);
    start_called = true;
  };
  se_.stop_timer = [](const SP_Device* const device, SP_Stream stream,
                      SP_Timer timer, TF_Status* const status) {
    TF_SetStatus(status, stop_timer_status, "");
    EXPECT_EQ(timer->timer_id, 7);
    stop_called = true;
  };
  StreamExecutor* executor = GetExecutor(0);
  Stream stream(executor);
  stream.Init();
  Timer timer(executor);
  stream.InitTimer(&timer);

  // Check that both start and stop succeed.
  ASSERT_FALSE(start_called);
  stream.ThenStartTimer(&timer);
  ASSERT_TRUE(start_called);
  ASSERT_FALSE(stop_called);
  stream.ThenStopTimer(&timer);
  ASSERT_TRUE(stop_called);

  // Check that starting the timer fails.
  ASSERT_TRUE(stream.ok());
  start_timer_status = TF_UNKNOWN;
  stream.ThenStartTimer(&timer);
  ASSERT_FALSE(stream.ok());

  // Check that stopping the timer fails.
  start_timer_status = TF_OK;
  stop_timer_status = TF_UNKNOWN;
  Stream stream2(executor);
  stream2.Init();
  Timer timer2(executor);
  stream2.InitTimer(&timer2);
  stream2.ThenStartTimer(&timer2);
  ASSERT_TRUE(stream2.ok());
  stream2.ThenStopTimer(&timer2);
  ASSERT_FALSE(stream2.ok());
}

TEST_F(StreamExecutorTest, TimerFns) {
  se_.create_timer = [](const SP_Device* const device, SP_Timer* timer,
                        TF_Status* const status) -> void {
    *timer = new SP_Timer_st(25000);
  };
  se_.destroy_timer = [](const SP_Device* const device,
                         SP_Timer timer) -> void { delete timer; };

  StreamExecutor* executor = GetExecutor(0);
  Stream stream(executor);
  stream.Init();
  Timer timer(executor);
  stream.InitTimer(&timer);
  // Our test nanoseconds function just returns the value
  // passed to the SP_Timer_st constructor.
  ASSERT_EQ(timer.Nanoseconds(), 25000);
  ASSERT_EQ(timer.Microseconds(), 25);
}

TEST_F(StreamExecutorTest, MemcpyToHost) {
|
||||
se_.create_stream = [](const SP_Device* const device, SP_Stream* stream,
|
||||
TF_Status* const status) -> void {
|
||||
*stream = new SP_Stream_st(14);
|
||||
};
|
||||
se_.destroy_stream = [](const SP_Device* const device,
|
||||
SP_Stream stream) -> void { delete stream; };
|
||||
|
||||
se_.memcpy_dtoh = [](const SP_Device* const device, SP_Stream stream,
|
||||
void* host_dst,
|
||||
const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
EXPECT_EQ(stream->stream_id, 14);
|
||||
std::memcpy(host_dst, device_src->opaque, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 34;
|
||||
int dst_data = 2;
|
||||
DeviceMemoryBase device_src(&src_data, size);
|
||||
Stream& stream_ref = stream.ThenMemcpy(&dst_data, device_src, size);
|
||||
ASSERT_EQ(dst_data, 34);
|
||||
ASSERT_EQ(stream_ref.implementation(), stream.implementation());
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, MemcpyFromHost) {
|
||||
se_.memcpy_htod = [](const SP_Device* const device, SP_Stream stream,
|
||||
SP_DeviceMemoryBase* const device_dst,
|
||||
const void* host_src, uint64_t size,
|
||||
TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
std::memcpy(device_dst->opaque, host_src, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 18;
|
||||
int dst_data = 0;
|
||||
DeviceMemoryBase device_dst(&dst_data, size);
|
||||
stream.ThenMemcpy(&device_dst, &src_data, size);
|
||||
ASSERT_EQ(dst_data, 18);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, MemcpyDeviceToDevice) {
|
||||
se_.memcpy_dtod = [](const SP_Device* const device, SP_Stream stream,
|
||||
SP_DeviceMemoryBase* const device_dst,
|
||||
const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
std::memcpy(device_dst->opaque, device_src->opaque, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 18;
|
||||
int dst_data = 0;
|
||||
DeviceMemoryBase device_dst(&dst_data, size);
|
||||
DeviceMemoryBase device_src(&src_data, size);
|
||||
stream.ThenMemcpy(&device_dst, device_src, size);
|
||||
ASSERT_EQ(dst_data, 18);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, SyncMemcpyToHost) {
|
||||
se_.sync_memcpy_dtoh = [](const SP_Device* const device, void* host_dst,
|
||||
const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
std::memcpy(host_dst, device_src->opaque, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 34;
|
||||
int dst_data = 2;
|
||||
DeviceMemoryBase device_src(&src_data, size);
|
||||
TF_ASSERT_OK(executor->SynchronousMemcpyD2H(device_src, size, &dst_data));
|
||||
ASSERT_EQ(dst_data, 34);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, SyncMemcpyFromHost) {
|
||||
se_.sync_memcpy_htod =
|
||||
[](const SP_Device* const device, SP_DeviceMemoryBase* const device_dst,
|
||||
const void* host_src, uint64_t size, TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
std::memcpy(device_dst->opaque, host_src, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 18;
|
||||
int dst_data = 0;
|
||||
DeviceMemoryBase device_dst(&dst_data, size);
|
||||
TF_ASSERT_OK(executor->SynchronousMemcpyH2D(&src_data, size, &device_dst));
|
||||
ASSERT_EQ(dst_data, 18);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, SyncMemcpyDeviceToDevice) {
|
||||
se_.sync_memcpy_dtod = [](const SP_Device* const device,
|
||||
SP_DeviceMemoryBase* const device_dst,
|
||||
const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
std::memcpy(device_dst->opaque, device_src->opaque, size);
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
size_t size = sizeof(int);
|
||||
int src_data = 18;
|
||||
int dst_data = 0;
|
||||
DeviceMemoryBase device_dst(&dst_data, size);
|
||||
DeviceMemoryBase device_src(&src_data, size);
|
||||
ASSERT_TRUE(executor->SynchronousMemcpy(&device_dst, device_src, size));
|
||||
ASSERT_EQ(dst_data, 18);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, BlockHostForEvent) {
|
||||
static bool block_host_for_event_called = false;
|
||||
se_.create_event = [](const SP_Device* const device, SP_Event* event,
|
||||
TF_Status* const status) {
|
||||
*event = new SP_Event_st(357);
|
||||
};
|
||||
se_.destroy_event = [](const SP_Device* const device, SP_Event event) {
|
||||
delete event;
|
||||
};
|
||||
se_.block_host_for_event = [](const SP_Device* const device, SP_Event event,
|
||||
TF_Status* const status) -> void {
|
||||
ASSERT_EQ(event->event_id, 357);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
block_host_for_event_called = true;
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
ASSERT_FALSE(block_host_for_event_called);
|
||||
TF_ASSERT_OK(stream.BlockHostUntilDone());
|
||||
ASSERT_TRUE(block_host_for_event_called);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, SynchronizeAllActivity) {
|
||||
static bool synchronize_all_called = false;
|
||||
se_.synchronize_all_activity = [](const SP_Device* const device,
|
||||
TF_Status* const status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
synchronize_all_called = true;
|
||||
};
|
||||
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
ASSERT_FALSE(synchronize_all_called);
|
||||
ASSERT_TRUE(executor->SynchronizeAllActivity());
|
||||
ASSERT_TRUE(synchronize_all_called);
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, HostCallbackOk) {
|
||||
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
|
||||
SE_StatusCallbackFn const callback_fn,
|
||||
void* const callback_arg) -> TF_Bool {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
callback_fn(callback_arg, status);
|
||||
bool ok = TF_GetCode(status) == TF_OK;
|
||||
TF_DeleteStatus(status);
|
||||
return ok;
|
||||
};
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
std::function<port::Status()> callback = []() -> port::Status {
|
||||
return port::Status::OK();
|
||||
};
|
||||
stream.ThenDoHostCallbackWithStatus(callback);
|
||||
ASSERT_TRUE(stream.ok());
|
||||
}
|
||||
|
||||
TEST_F(StreamExecutorTest, HostCallbackError) {
|
||||
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
|
||||
SE_StatusCallbackFn const callback_fn,
|
||||
void* const callback_arg) -> TF_Bool {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
callback_fn(callback_arg, status);
|
||||
bool ok = TF_GetCode(status) == TF_OK;
|
||||
TF_DeleteStatus(status);
|
||||
return ok;
|
||||
};
|
||||
StreamExecutor* executor = GetExecutor(0);
|
||||
Stream stream(executor);
|
||||
stream.Init();
|
||||
std::function<port::Status()> callback = []() -> port::Status {
|
||||
return port::UnimplementedError("Unimplemented");
|
||||
};
|
||||
stream.ThenDoHostCallbackWithStatus(callback);
|
||||
ASSERT_FALSE(stream.ok());
|
||||
}
|
||||
} // namespace
|
||||
} // namespace stream_executor
|
@ -280,6 +280,36 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
  return tf_tensor;
}

TF_Tensor* TF_ForwardInputOrAllocateOutput(
    TF_OpKernelContext* context, int* candidate_input_indices,
    int num_candidate_input_indices, int output_index, int64_t* output_dims,
    int output_num_dims, int* forwarded_input, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);

  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  tensorflow::gtl::ArraySlice<int> input_indices_array(
      candidate_input_indices, num_candidate_input_indices);
  tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
      reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
  tensorflow::Tensor* output_tensor_pointer;
  tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
      input_indices_array, output_index,
      tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
      forwarded_input);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return tf_tensor_output;
}

TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
                           int64_t* dims, int num_dims,
                           TF_AllocatorAttributes* attributes,
@ -200,6 +200,17 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
                                            int64_t* dims, int num_dims,
                                            size_t len, TF_Status* status);

// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer. The index of the
// forwarded input will be assigned to the output argument forwarded_input (if
// it's not nullptr). If no inputs are forwarded, forwarded_input will be
// assigned -1.
TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
    TF_OpKernelContext* context, int* candidate_input_indices,
    int num_candidate_input_indices, int output_index, int64_t* output_dims,
    int output_num_dims, int* forwarded_input, TF_Status* status);

// Allocates a temporary Tensor of the specified type and shape. The
// Tensor must not be used after kernel construction is
// complete.
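For context, here is a minimal sketch of how a plugin kernel might call the API declared above; the function name `ExampleComputeFunc` and the scalar-forwarding setup are hypothetical, mirroring the kernel test added later in this diff.

// Hedged sketch: try to reuse input 0's buffer for output 0, falling back to
// a fresh scalar allocation. ExampleComputeFunc is an illustrative name only.
static void ExampleComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* s = TF_NewStatus();
  int candidate_input_indices[1] = {0};
  int forwarded_input = -1;
  int64_t output_dims[1] = {};  // Rank-0 (scalar) output.
  TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
      ctx, candidate_input_indices, /*num_candidate_input_indices=*/1,
      /*output_index=*/0, output_dims, /*output_num_dims=*/0,
      &forwarded_input, s);
  // forwarded_input == 0 means input 0's buffer was reused; -1 means a new
  // buffer was allocated.
  if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(output);
  TF_DeleteStatus(s);
}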
@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -565,6 +565,74 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
            output->DebugString(100));
}

TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
  const char* node_name = "TestForwardInputOrAllocateOutputKernel";
  const char* op_name = "BazOp";
  const char* device_name = "FakeDeviceName";

  REGISTER_OP(op_name)
      .Input("input1: float")
      .Input("input2: float")
      .Output("output1: float")
      .Attr("SomeDataTypeAttr: type");

  // A kernel whose Compute function forwards a scalar input to the output.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    int candidate_input_indices[1] = {0};
    int forwarded_input;
    int64_t output_dims[1] = {};
    TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
        /*context=*/ctx, candidate_input_indices,
        /*num_candidate_input_indices=*/1,
        /*output_index=*/0, output_dims, /*output_num_dims=*/0,
        &forwarded_input, /*status=*/s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    EXPECT_EQ(forwarded_input, 0);
    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
    EXPECT_EQ(0, TF_NumDims(output));
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    AllocatorAttributes alloc_attrs;
    p.output_attr_array = &alloc_attrs;

    Tensor t(123.0f);

    gtl::InlinedVector<TensorValue, 4> inputs;
    // GetFakeKernel requires a NodeDef with two inputs.
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = &inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
  }
}

void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                     TF_DataType dtype) {
  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
@ -28,6 +28,7 @@ void TF_Log(TF_LogLevel level, const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  auto message = BuildMessage(fmt, args);
  va_end(args);
  switch (level) {
    case TF_INFO:
      LOG(INFO) << message;
@ -48,6 +49,7 @@ void TF_VLog(int level, const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  auto message = BuildMessage(fmt, args);
  va_end(args);
  VLOG(level) << message;
}

@ -55,5 +57,6 @@ void TF_DVLog(int level, const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  auto message = BuildMessage(fmt, args);
  va_end(args);
  DVLOG(level) << message;
}
@ -47,6 +47,7 @@ cc_library(
        # TODO(b/111634734): :lib and :protos_all contain dependencies that
        # cannot be built on mobile platforms. Instead, include the appropriate
        # tf_lib depending on the build platform.
        "@com_google_absl//absl/memory:memory",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ]),
@ -171,6 +172,7 @@ tf_cc_test(
    deps = [
        ":constants",
        ":loader",
        ":reader",
        ":signature_constants",
        ":tag_constants",
        "//tensorflow/core:lib",
@ -51,8 +51,32 @@ cc_library(
    deps = [
        ":concrete_function",
        ":concrete_function_list",
        ":signature_def_function",
        "//tensorflow/c/experimental/saved_model/public:saved_model_api",
        "//tensorflow/cc/experimental/base/public:runtime",
        "//tensorflow/cc/experimental/base/public:status",
    ],
)

cc_library(
    name = "signature_def_function",
    hdrs = [
        "signature_def_function.h",
    ],
    deps = [
        ":signature_def_function_metadata",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/experimental/saved_model/public:signature_def_function",
        "//tensorflow/cc/experimental/base/public:status",
    ],
)

cc_library(
    name = "signature_def_function_metadata",
    hdrs = [
        "signature_def_function_metadata.h",
    ],
    deps = [
        "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
    ],
)
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
#include "tensorflow/cc/saved_model/experimental/public/signature_def_function.h"

namespace tensorflow {
namespace experimental {
@ -80,8 +81,8 @@ class SavedModelAPI {
  // If status is not OK, returns nullptr. Otherwise, returns a
  // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
  // is bound to the SavedModelAPI it was loaded from.
  ConcreteFunction* GetSignatureDefFunction(const std::string& function_path,
                                            Status* status);
  SignatureDefFunction* GetSignatureDefFunction(
      const std::string& function_path, Status* status);

  // Lists all Concrete Functions available from the SavedModel.
  std::vector<ConcreteFunction*> ListFunctions();
@ -140,14 +141,14 @@ inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
  return ConcreteFunction::wrap(function);
}

inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
    const std::string& function_path, Status* status) {
  TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
  TF_SignatureDefFunction* function = TF_GetSavedModelSignatureDefFunction(
      saved_model_.get(), function_path.c_str(), status->GetTFStatus());
  if (!status->ok()) {
    return nullptr;
  }
  return ConcreteFunction::wrap(function);
  return SignatureDefFunction::wrap(function);
}

inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_

#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h"

namespace tensorflow {
namespace experimental {
namespace cc {

// SignatureDefFunctions are functions that correspond to either:
// "signatures" saved from the TF2 SavedModel APIs:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/save.py#L830-L854
// Or the "SignatureDefMap" saved from the TF1 SavedModel APIs:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/load_v1_in_v2_test.py#L170-L174
// In both cases, a SignatureDef is serialized as a SignatureDef protobuf:
// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/core/protobuf/meta_graph.proto#L260-L330
// and represents a computation defined by a TF subgraph.
// These Signatures were primarily designed to be interoperable with the legacy
// TF 1 Session-based C++ SavedModelBundle loading APIs:
// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/cc/saved_model/loader.h#L96-L108
// SignatureDefFunctions have different semantics from regular TF2
// ConcreteFunctions, and are mainly intended to provide a serving-friendly
// transition point from the TF1 Session API.
// First, SignatureDefFunctions have different calling conventions.
// SignatureDefFunctions' inputs and outputs are constrained to **flattened
// lists of TensorHandles only**. They do not support more exotic input/output
// types (like optionals, generators, etc). Additionally, this flattening means
// they will not preserve the exact interface of the original tf.function they
// were traced from, as things like composite tensors decay into their
// internal dense tensor representation.
// Second, all inputs and outputs are "named", and these names are load bearing
// (eg: they are part of the interface of tensorflow_serving):
// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L21
// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L39
// The name of each input/output is stored in the corresponding tf::Argument in
// SignatureDefFunctionMetadata::arguments(). Users must ensure the order of
// TensorHandles passed to the function matches the order of named
// arguments. Similarly, the names of the outputs are stored in
// SignatureDefFunctionMetadata::returns().
class SignatureDefFunction final {
 public:
  // Returns FunctionMetadata associated with this SignatureDefFunction.
  const SignatureDefFunctionMetadata* GetFunctionMetadata();

 private:
  friend class SavedModelAPI;
  friend class ConcreteFunctionList;

  // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
  // when moving out of experimental.
  static SignatureDefFunction* wrap(TF_SignatureDefFunction* p) {
    return reinterpret_cast<SignatureDefFunction*>(p);
  }
  static TF_SignatureDefFunction* unwrap(SignatureDefFunction* p) {
    return reinterpret_cast<TF_SignatureDefFunction*>(p);
  }
};

inline const SignatureDefFunctionMetadata*
SignatureDefFunction::GetFunctionMetadata() {
  return SignatureDefFunctionMetadata::wrap(
      TF_SignatureDefFunctionGetMetadata(unwrap(this)));
}

}  // namespace cc
}  // namespace experimental
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
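As a hedged usage sketch (not part of this commit): given an already-loaded `SavedModelAPI` named `saved_model`, looking up a signature and reading its metadata might look like the following; the signature key "serving_default" is only an example.

// Hedged sketch: assumes `saved_model` is a successfully loaded
// std::unique_ptr<SavedModelAPI>; "serving_default" is a placeholder key.
tensorflow::experimental::cc::Status status;
tensorflow::experimental::cc::SignatureDefFunction* fn =
    saved_model->GetSignatureDefFunction("serving_default", &status);
if (status.ok()) {
  // Metadata will carry the load-bearing input/output names described above;
  // getters are still TODO in the metadata header that follows.
  const tensorflow::experimental::cc::SignatureDefFunctionMetadata* metadata =
      fn->GetFunctionMetadata();
  (void)metadata;
}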
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_

#include <memory>

#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"

namespace tensorflow {
namespace experimental {
namespace cc {

// SignatureDefFunctionMetadata stores additional information on each input
// and output: their names, dtypes, and shapes.
class SignatureDefFunctionMetadata final {
  // TODO(bmzhao): Add getters here as necessary.
 private:
  friend class SignatureDefFunction;
  static SignatureDefFunctionMetadata* wrap(
      TF_SignatureDefFunctionMetadata* p) {
    return reinterpret_cast<SignatureDefFunctionMetadata*>(p);
  }
  static TF_SignatureDefFunctionMetadata* unwrap(
      SignatureDefFunctionMetadata* p) {
    return reinterpret_cast<TF_SignatureDefFunctionMetadata*>(p);
  }
};

}  // namespace cc
}  // namespace experimental
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
  return Status::OK();
}

Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                const SessionOptions& session_options,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  session->reset(session_p);
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
  return (*session)->Create(meta_graph_def.graph_def());
}

Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<tstring>()() = value;
@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  nullptr /* outputs */, &run_metadata, session);
}

Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
  LOG(INFO) << "Reading SavedModel debug info (if present) from: "
            << export_dir;
}  // namespace

  const string debug_info_pb_path =
      io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
  if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
    GraphDebugInfo debug_info;
    TF_RETURN_IF_ERROR(
        ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
    *debug_info_proto =
        absl::make_unique<GraphDebugInfo>(std::move(debug_info));
  }
  return Status::OK();
SavedModelBundleInterface::~SavedModelBundleInterface() {}

Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  session->reset(session_p);
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
  return (*session)->Create(meta_graph.graph_def());
}

Status LoadSavedModelInternal(const SessionOptions& session_options,
@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
  const uint64 read_start_microseconds = Env::Default()->NowMicros();
  TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                                    &bundle->meta_graph_def));
  TF_RETURN_IF_ERROR(
      ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
  TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
      bundle->meta_graph_def, session_options, &bundle->session));

  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
      internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
  TF_RETURN_IF_ERROR(
      RunRestore(run_options, export_dir,
                 bundle->meta_graph_def.saver_def().restore_op_name(),
                 bundle->meta_graph_def.saver_def().filename_tensor_name(),
                 asset_file_defs, bundle->session.get()));
  // Record walltime spent in restoring graph from disk, but postpone metric
  // increments until graph init finishes.
  const uint64 restore_graph_walltime =
      GetLatencyMicroseconds(read_start_microseconds);

  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
                               asset_file_defs, bundle->session.get(),
                               init_op_name));
  load_latency_by_stage->GetCell(export_dir, "restore_graph")
      ->Add(restore_graph_walltime);
  // Record wall time spent in init op.
  load_latency_by_stage->GetCell(export_dir, "init_graph")
      ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
  TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
      session_options, bundle->meta_graph_def, &bundle->session));
  TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
                                    export_dir, &bundle->session));
  return Status::OK();
}

}  // namespace

SavedModelBundleInterface::~SavedModelBundleInterface() {}

Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
||||
};
|
||||
} // namespace
|
||||
|
||||
Status RestoreSession(const RunOptions& run_options,
|
||||
const MetaGraphDef& meta_graph, const string& export_dir,
|
||||
std::unique_ptr<Session>* session) {
|
||||
const uint64 read_start_microseconds = Env::Default()->NowMicros();
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
|
||||
meta_graph.saver_def().restore_op_name(),
|
||||
meta_graph.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, session->get()));
|
||||
// Record walltime spent in restoring graph from disk, but postpone metric
|
||||
// increments until graph init finishes.
|
||||
const uint64 restore_graph_walltime =
|
||||
GetLatencyMicroseconds(read_start_microseconds);
|
||||
|
||||
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::GetInitOp(export_dir, meta_graph, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
|
||||
asset_file_defs, session->get(), init_op_name));
|
||||
load_latency_by_stage->GetCell(export_dir, "restore_graph")
|
||||
->Add(restore_graph_walltime);
|
||||
// Record wall time spent in init op.
|
||||
load_latency_by_stage->GetCell(export_dir, "init_graph")
|
||||
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadSavedModel(const SessionOptions& session_options,
|
||||
const RunOptions& run_options, const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
|
@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface {
  protobuf::Map<string, SignatureDef> signatures_;
};

// Restores variables and resources in the SavedModel export dir for the
// indicated metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status RestoreSession(const RunOptions& run_options,
                      const MetaGraphDef& meta_graph, const string& export_dir,
                      std::unique_ptr<Session>* session);

// Initializes a session which wraps this metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session);

/// Loads a SavedModel from the specified export directory. The MetaGraphDef
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in
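For orientation, a hedged sketch of the manual path these two functions enable, equivalent to LoadSavedModel and exercised by a loader test later in this diff; the export path is a placeholder.

// Hedged sketch: assumes this runs inside a function returning Status and
// that "/tmp/my_saved_model" is replaced with a real export directory.
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir = "/tmp/my_saved_model";
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(
    export_dir, {kSavedModelTagServe}, &bundle.meta_graph_def));
TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
    session_options, bundle.meta_graph_def, &bundle.session));
TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle.meta_graph_def,
                                  export_dir, &bundle.session));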
@ -17,6 +17,7 @@ limitations under the License.

#include <unordered_set>

#include "absl/memory/memory.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
  return Status::OK();
}

Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
  LOG(INFO) << "Reading SavedModel debug info (if present) from: "
            << export_dir;

  const string debug_info_pb_path =
      io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
  if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
    GraphDebugInfo debug_info;
    TF_RETURN_IF_ERROR(
        ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
    *debug_info_proto =
        absl::make_unique<GraphDebugInfo>(std::move(debug_info));
  }
  return Status::OK();
}

}  // namespace tensorflow
@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
                                      const std::unordered_set<string>& tags,
                                      MetaGraphDef* const meta_graph_def);

// Reads debug info, if present, from the SavedModel export dir.
Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto);

}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||
EXPECT_FALSE(st.ok());
|
||||
}
|
||||
|
||||
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
std::unique_ptr<GraphDebugInfo> debug_info_proto;
|
||||
TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info_proto));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace {
@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) {
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) {
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  MetaGraphDef actual_metagraph;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                              &actual_metagraph));
  EXPECT_EQ(actual_metagraph.DebugString(),
            bundle.meta_graph_def.DebugString());
}

TEST_F(LoaderTest, RestoreSession) {
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));

  SavedModelBundle actual_bundle;
  const std::unordered_set<std::string> tags = {kSavedModelTagServe};
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                              &actual_bundle.meta_graph_def));
  TF_ASSERT_OK(LoadMetagraphIntoSession(
      session_options, actual_bundle.meta_graph_def, &actual_bundle.session));
  TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def,
                              export_dir, &actual_bundle.session));
  CheckSavedModelBundle(export_dir, actual_bundle);
}

TEST_F(LoaderTest, NoTagMatch) {
  SavedModelBundle bundle;
  RunOptions run_options;
@ -278,16 +278,14 @@ Status XlaCompilationCache::CompileSingleOp(
    const NodeDef& node_def = ctx->op_kernel().def();
    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));

    bool are_args_supported =
        absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
          return arg.kind == XlaCompiler::Argument::kConstant ||
                 arg.kind == XlaCompiler::Argument::kParameter;
    bool has_tensor_list_arg =
        absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
          return arg.kind == XlaCompiler::Argument::kTensorList;
        });
    const ConfigProto* config = ctx->function_library()->config_proto();
    bool use_mlir = config && config->experimental().enable_mlir_bridge();
    // TODO(b/155596779): Understand the source of other argument types and
    // depending on the source either support those or avoid these codepath.
    if (!use_mlir || !are_args_supported) {
    // TODO(b/155596779): Support TensorList args.
    if (!use_mlir || !has_tensor_list_arg) {
      return compiler->CompileGraph(compile_options, node_def.name(),
                                    std::move(graph), args, result);
    }
@ -40,13 +40,16 @@ cc_library(
    srcs = ["tf_mlir_opt_main.cc"],
    deps = [
        ":init_mlir",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:Support",
    ],
)
@ -127,9 +130,7 @@ tf_cc_binary(
    deps = [
        ":passes",
        ":tf_mlir_opt_main",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
        "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
        "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
@ -813,7 +813,8 @@ cc_binary(
    ],
    deps = [
        ":all_passes",
        ":hlo_dialect_registration",
        ":hlo",
        ":lhlo",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
        "@llvm-project//mlir:IR",
@ -56,19 +56,9 @@ class MhloDialect : public Dialect {
  void printType(Type type, DialectAsmPrinter &os) const override;
};

namespace HLOTypes {
enum Kind {
  Token = Type::FIRST_XLA_HLO_TYPE,
};
}  // namespace HLOTypes

class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 public:
  using Base::Base;

  static TokenType get(MLIRContext *context) {
    return Base::get(context, HLOTypes::Token);
  }
};

// Shape derivation function that computes the shape of the result based on
@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
    ElementsAttr:$value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );

  let hasCanonicalizer = 1;
}

def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {

// Returns DenseElementsAttr of rank zero with the given element type and the
// value.
// Requires `ty` to be either FloatType or IntegerType.
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);

// Enum type used to specify scalar argument to GetScalarLimitOfType.
enum ScalarLimit {
  kLowest,          // The scalar corresponding to numeric_limits<T>::lowest.
  kInfinityLowest,  // Like kLowest, but returns -infinity where available.
  kMax,             // The scalar corresponding to numeric_limits<T>::max.
  kInfinityMax,     // Like kMax, but returns infinity where available.
};

// Returns a scalar limit value for the given type.
//
// The argument 'limit' describes which scalar value to return.
//
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);

}  // namespace hlo
}  // namespace mlir

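A hedged sketch of how a rewrite pattern might use these helpers, assuming an mlir::Builder `b` is in scope; the choice of f32 is arbitrary.

// Hedged sketch: materialize a rank-0 complex zero and a float -infinity.
mlir::Type c64 = mlir::ComplexType::get(b.getF32Type());
mlir::DenseElementsAttr complex_zero = mlir::hlo::GetScalarOfType(c64, 0);
mlir::DenseElementsAttr neg_inf =
    mlir::hlo::GetScalarLimitOfType(b.getF32Type(), mlir::hlo::kInfinityLowest);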
@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
      >();
}

//===----------------------------------------------------------------------===//
// ConstOp.
//===----------------------------------------------------------------------===//

/// An lmhlo.constant on a memref that is locally allocated and has no other
/// users (other than dealloc's) can be erased.
// TODO: This can be generalized to an arbitrary op by making use of memory
// effects (write memory effect).
struct EraseConstOp : public OpRewritePattern<ConstOp> {
  using OpRewritePattern<ConstOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConstOp op,
                                PatternRewriter& rewriter) const override {
    Value memref = op.output();
    if (!memref.getDefiningOp<AllocOp>()) {
      return failure();
    }

    // Check that all uses of the memref are either DeallocOps or this op.
    for (Operation* user : memref.getUsers())
      if (user != op && !isa<DeallocOp>(user)) return failure();

    rewriter.eraseOp(op);
    return success();
  }
};

void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<EraseConstOp>(context);
}

//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
@ -15,6 +15,8 @@ limitations under the License.

// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include <numeric>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        dstShape.size());
    bool isExpandingOrCollapsing = true;
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
      int64_t dstSize = dstShape[currDstDim];
      int64_t srcSize = srcShape[currSrcDim];
@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
        }
      }
      } else {
        return failure();
        isExpandingOrCollapsing = false;
        break;
      }
      currDstDim++;
    }
    if (currSrcDim != srcShape.size()) return failure();
    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;

    if (!isExpandingOrCollapsing) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshapeOp.getLoc();
      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
                                           std::multiplies<int64_t>());
      auto elemType = operandType.getElementType();
      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
          getIdentityExprs(dstShape.size())};
      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
          getIdentityExprs(srcShape.size())};

      if (isLHLO) {
        auto collapsedType = MemRefType::get({totalElems}, elemType);
        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
            loc, collapsedType, args[0], collapsingMap);
        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
            loc, resultType, collapsedOp, expandingMap);
        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
            reshapeOp, reshapeBuffer, args[1], /*inputPermutation=*/nullptr,
            /*outputPermutation=*/nullptr);
      } else {
        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
            loc, collapsedType, args[0], collapsingMap);
        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
            reshapeOp, resultType, collapsedOp, expandingMap);
      }
      return success();
    }

    if (isLHLO) {
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    APFloat value(float_ty.getFloatSemantics(), raw_value);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
    APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
    Type complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<float>>(raw_value));
    } else if (complex_element_ty.isF64()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<double>>(raw_value));
    }
  }
  auto int_ty = ty.cast<IntegerType>();
  APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
  return DenseElementsAttr::get(scalar_ty, value);
  llvm_unreachable("unsupported type");
}

static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
                                         ScalarLimit limit) {
  auto &semantics = float_ty.getFloatSemantics();
  switch (limit) {
    case kLowest:
      return APFloat::getLargest(semantics, /*negative=*/true);
    case kInfinityLowest:
      return APFloat::getInf(semantics, /*negative=*/true);
    case kMax:
      return APFloat::getLargest(semantics, /*negative=*/false);
    case kInfinityMax:
      return APFloat::getInf(semantics, /*negative=*/false);
  }
  llvm_unreachable("invalid limit");
}

// Returns the scalar limit value of the given integer type, as selected by
// `limit`.
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
                                         ScalarLimit limit) {
  unsigned width = integer_ty.getWidth();
  switch (limit) {
    case kLowest:
    case kInfinityLowest:
      if (integer_ty.isUnsigned()) {
        return APInt::getMinValue(width);
      } else {
        return APInt::getSignedMinValue(width);
      }

    case kMax:
    case kInfinityMax:
      if (integer_ty.isUnsigned()) {
        return APInt::getMaxValue(width);
      } else {
        return APInt::getSignedMaxValue(width);
      }
  }
  llvm_unreachable("invalid limit");
}

DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    return DenseElementsAttr::get(scalar_ty,
                                  GetScalarLimitOfFloatType(float_ty, limit));
  } else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
    return DenseElementsAttr::get(
        scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
  }
  llvm_unreachable("unsupported type");
}

}  // namespace hlo
@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
  // CHECK: return [[ARG0]]
  return %3 : tuple<tensor<i32>>
}

// CHECK-LABEL: func @erase_dead_lhlo_constant
func @erase_dead_lhlo_constant() {
  %M = alloc() : memref<256x1024xf32>
  // CHECK-NEXT: return
  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
  dealloc %M : memref<256x1024xf32>
  return
}

// A negative test for dead lhlo constant op erasure.
// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
  // CHECK-NEXT: lmhlo.constant
  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
  // CHECK-NEXT: alloc
  // CHECK-NEXT: lmhlo.constant
  %N = alloc() : memref<256x1024xf32>
  "lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
  return %N : memref<256x1024xf32>
}
@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {

// -----

// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape_3D_4D
func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
  return %0 : tensor<1x784x1x1xf32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]

// -----

// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "mhlo.minimum"(%lhs, %rhs)