Allow forwarding inputs in eager
PiperOrigin-RevId: 226396223
This commit is contained in:
parent
e0963c4073
commit
0c31fca446
@ -148,6 +148,62 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
"c_api_experimental.cc",
|
||||
],
|
||||
hdrs = ["c_api_experimental.h"],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape",
|
||||
hdrs = ["tape.h"],
|
||||
|
23
tensorflow/c/eager/c_api_experimental.cc
Normal file
23
tensorflow/c/eager/c_api_experimental.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
|
||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
op->operation.ConsumeInput(h->handle);
|
||||
}
|
32
tensorflow/c/eager/c_api_experimental.h
Normal file
32
tensorflow/c/eager/c_api_experimental.h
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
|
@ -30,4 +30,9 @@ void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
|
||||
inputs_.push_back(h);
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
}
|
||||
|
||||
void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
|
||||
inputs_.push_back(h);
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -53,6 +53,7 @@ class EagerOperation {
|
||||
return &inputs_;
|
||||
}
|
||||
void AddInput(tensorflow::TensorHandle* h);
|
||||
void ConsumeInput(tensorflow::TensorHandle* h);
|
||||
|
||||
const tensorflow::string& Name() const { return name_; }
|
||||
const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }
|
||||
|
@ -712,22 +712,37 @@ Status EagerExecute(EagerContext* ctx, Device* device,
|
||||
std::vector<Tensor> outputs(1);
|
||||
const MemoryTypeVector* output_memory_types = nullptr;
|
||||
output_memory_types = &kernel->kernel()->output_memory_types();
|
||||
std::vector<Tensor> inputs(op_inputs.size());
|
||||
|
||||
// If there are multiple references to a TensorHandle in 'op_inputs' we must
|
||||
// increment the reference count of the corresponding Tensor or risk it being
|
||||
// overwritten during kernel execution. The reference count is incremented
|
||||
// below when we insert a copy of the Tensor into protected_tensors, and will
|
||||
// be decremented once execution is complete.
|
||||
std::vector<tensorflow::Tensor> protected_tensors;
|
||||
for (int i = 0; i < op_inputs.size(); ++i) {
|
||||
const Tensor* input_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
|
||||
inputs[i] = *input_tensor;
|
||||
if (!op_inputs[i]->RefCountIsOne()) {
|
||||
const Tensor* input_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
|
||||
protected_tensors.push_back(*input_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
|
||||
for (int i = 0; i < op_inputs.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
|
||||
}
|
||||
|
||||
// TODO(apassos) figure out how to record stats for ops which are a part of
|
||||
// functions.
|
||||
// TODO(agarwal): change Run to take vector of handles ?
|
||||
ScopedStepContainer* container = ctx->StepContainer();
|
||||
if (container == nullptr) {
|
||||
TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
|
||||
TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
|
||||
maybe_step_stats, graph_collector));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
|
||||
maybe_step_stats, graph_collector));
|
||||
TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
|
||||
maybe_stats, maybe_step_stats,
|
||||
graph_collector));
|
||||
}
|
||||
if (maybe_stats != nullptr) {
|
||||
int64 nanos = Env::Default()->NowNanos();
|
||||
|
@ -57,7 +57,7 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flr,
|
||||
return OutputTypesForNode(ndef, *op_def, &out->output_dtypes_);
|
||||
}
|
||||
|
||||
Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
|
||||
Status KernelAndDevice::Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats,
|
||||
GraphCollector* graph_collector) {
|
||||
@ -69,15 +69,10 @@ Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
|
||||
}
|
||||
|
||||
Status KernelAndDevice::Run(ScopedStepContainer* step_container,
|
||||
std::vector<Tensor>* inputs,
|
||||
const gtl::InlinedVector<TensorValue, 4>& inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats,
|
||||
GraphCollector* graph_collector) {
|
||||
gtl::InlinedVector<TensorValue, 4> input_vector;
|
||||
for (Tensor& t : *inputs) {
|
||||
input_vector.push_back(TensorValue(&t));
|
||||
}
|
||||
|
||||
std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
|
||||
for (size_t i = 0; i < out_attrs.size(); ++i) {
|
||||
out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
|
||||
@ -85,7 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
|
||||
}
|
||||
|
||||
gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
|
||||
for (int i = 0; i < inputs->size(); i++) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
DeviceContext* device_context = nullptr;
|
||||
if (device_->tensorflow_gpu_device_info() != nullptr) {
|
||||
device_context = device_->tensorflow_gpu_device_info()->default_context;
|
||||
@ -96,7 +91,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
|
||||
OpKernelContext::Params params;
|
||||
params.device = device_;
|
||||
params.frame_iter = FrameAndIter(0, 0);
|
||||
params.inputs = &input_vector;
|
||||
params.inputs = &inputs;
|
||||
params.op_kernel = kernel_.get();
|
||||
params.resource_manager = device_->resource_manager();
|
||||
params.output_attr_array = gtl::vector_as_array(&out_attrs);
|
||||
|
@ -68,11 +68,12 @@ class KernelAndDevice {
|
||||
collective_executor_(std::move(collective_executor)) {}
|
||||
|
||||
// TODO(ashankar): Handle list-valued inputs.
|
||||
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
|
||||
NodeExecStats* stats, StepStats* step_stats,
|
||||
GraphCollector* graph_collector);
|
||||
Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats, GraphCollector* graph_collector);
|
||||
|
||||
Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
|
||||
Status Run(ScopedStepContainer* step_container,
|
||||
const gtl::InlinedVector<TensorValue, 4>& inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats, GraphCollector* graph_collector);
|
||||
|
||||
|
@ -118,9 +118,9 @@ BENCHMARK(BM_KernelAndDeviceInit);
|
||||
void BM_KernelAndDeviceRun(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
|
||||
std::vector<Tensor> inputs;
|
||||
inputs.push_back(t);
|
||||
inputs.push_back(t);
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
inputs.push_back(TensorValue(&t));
|
||||
inputs.push_back(TensorValue(&t));
|
||||
std::vector<Tensor> outputs;
|
||||
NodeDef ndef(AttrBuilder("MatMul")
|
||||
.Set("T", DT_FLOAT)
|
||||
@ -134,7 +134,7 @@ void BM_KernelAndDeviceRun(int iters) {
|
||||
nullptr, &kernel));
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
|
||||
TF_CHECK_OK(kernel.Run(inputs, &outputs, nullptr, nullptr, nullptr));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_KernelAndDeviceRun);
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
@ -79,6 +78,13 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
DCHECK(IsReady());
|
||||
*t = tensorflow::TensorValue(&tensor_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
|
||||
tensorflow::Device** device,
|
||||
tensorflow::Device** op_device) {
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
@ -102,6 +101,8 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
Status Tensor(const tensorflow::Tensor** t);
|
||||
|
||||
Status TensorValue(tensorflow::TensorValue* t);
|
||||
|
||||
tensorflow::Device* device() const { return device_; }
|
||||
|
||||
tensorflow::Device* op_device() const { return op_device_; }
|
||||
|
Loading…
Reference in New Issue
Block a user