Allow forwarding inputs in eager

PiperOrigin-RevId: 226396223
A. Unique TensorFlower authored on 2018-12-20 14:43:25 -08:00; committed by TensorFlower Gardener
parent e0963c4073
commit 0c31fca446
11 changed files with 161 additions and 26 deletions
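
Before this change, EagerExecute materialized a std::vector<Tensor> holding a copy of every input, and KernelAndDevice::Run wrapped pointers to those copies. Each copy held an extra reference to the underlying buffer, so a kernel never saw a uniquely owned input and could not reuse (forward) an input buffer for its output. After this change, Run receives TensorValue pointers to the handles' own tensors, and only tensors whose handles are still referenced elsewhere are copied (see protected_tensors in execute.cc below). What follows is a minimal sketch, not part of this commit, of the kernel-side forwarding this enables; it assumes the standard OpKernelContext::forward_input_or_allocate_output API, and the kernel itself is hypothetical.

#include "tensorflow/core/framework/op_kernel.h"

// Hypothetical elementwise kernel: squares input 0 into output 0, reusing
// the input buffer when nothing else references it.
class InPlaceSquareOp : public tensorflow::OpKernel {
 public:
  explicit InPlaceSquareOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    const tensorflow::Tensor& input = ctx->input(0);
    tensorflow::Tensor* output = nullptr;
    // Forward input 0's buffer to output 0 if its refcount is one --
    // which, after this commit, can finally be true in eager mode.
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {0}, 0, input.shape(), &output));
    auto in = input.flat<float>();
    auto out = output->flat<float>();
    for (int64_t i = 0; i < in.size(); ++i) out(i) = in(i) * in(i);
  }
};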

@@ -148,6 +148,62 @@ tf_cuda_cc_test(
],
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
],
hdrs = ["c_api_experimental.h"],
copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
":c_api_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
}) + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:xla_device",
],
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core:gpu_runtime",
],
)
cc_library(
name = "tape",
hdrs = ["tape.h"],

@@ -0,0 +1,23 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
}
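
A hedged usage sketch (not from this commit) of the new entry point, written against the public eager C API. It assumes, based on EagerOperation::ConsumeInput below omitting the Ref() that AddInput performs, that the op takes over the caller's reference to the handle; everything else (TFE_NewOp, TFE_Execute, the "Square" op) is the standard API, and error checks are elided.

// ctx is a TFE_Context*; h is a TFE_TensorHandle* the caller will not reuse.
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Square", status);
TFE_OpConsumeInput(op, h, status);  // op takes over the caller's reference
// Do not pass h to another op after this point: if this was the last
// reference, the kernel is free to reuse (forward) h's buffer for its output.
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(op, retvals, &num_retvals, status);
TFE_DeleteOp(op);
TF_DeleteStatus(status);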

@@ -0,0 +1,32 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_

@@ -30,4 +30,9 @@ void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
} // namespace tensorflow
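
ConsumeInput differs from AddInput only in what it does not do: AddInput begins with h->Ref() (just above the context shown), while ConsumeInput pushes the handle without taking a new reference. Assuming EagerOperation ultimately Unrefs each handle in inputs_, which the missing Ref() implies, ConsumeInput effectively takes over the caller's reference. A toy sketch of that contract, using core::RefCounted directly (none of this is TF code):

#include <vector>

#include "tensorflow/core/lib/core/refcount.h"

struct Handle : public tensorflow::core::RefCounted {};

// AddInput-style: the op takes its own reference; the caller keeps theirs.
void AddStyle(std::vector<Handle*>* inputs, Handle* h) {
  h->Ref();
  inputs->push_back(h);
}

// ConsumeInput-style: the op takes over the caller's reference.
void ConsumeStyle(std::vector<Handle*>* inputs, Handle* h) {
  inputs->push_back(h);
}

int main() {
  auto* h = new Handle;  // refcount == 1, held by the caller
  std::vector<Handle*> inputs;
  ConsumeStyle(&inputs, h);
  // The caller must not touch h again. h->RefCountIsOne() now holds from
  // the op's side, which is the precondition for forwarding its buffer.
  for (auto* p : inputs) p->Unref();  // final Unref deletes h
  return 0;
}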

@@ -53,6 +53,7 @@ class EagerOperation {
return &inputs_;
}
void AddInput(tensorflow::TensorHandle* h);
void ConsumeInput(tensorflow::TensorHandle* h);
const tensorflow::string& Name() const { return name_; }
const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }

@@ -712,22 +712,37 @@ Status EagerExecute(EagerContext* ctx, Device* device,
std::vector<Tensor> outputs(1);
const MemoryTypeVector* output_memory_types = nullptr;
output_memory_types = &kernel->kernel()->output_memory_types();
std::vector<Tensor> inputs(op_inputs.size());
// If there are multiple references to a TensorHandle in 'op_inputs' we must
// increment the reference count of the corresponding Tensor or risk it being
// overwritten during kernel execution. The reference count is incremented
// below when we insert a copy of the Tensor into protected_tensors, and will
// be decremented once execution is complete.
std::vector<tensorflow::Tensor> protected_tensors;
for (int i = 0; i < op_inputs.size(); ++i) {
if (!op_inputs[i]->RefCountIsOne()) {
const Tensor* input_tensor = nullptr;
TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
inputs[i] = *input_tensor;
protected_tensors.push_back(*input_tensor);
}
}
gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
for (int i = 0; i < op_inputs.size(); ++i) {
TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
}
// TODO(apassos) figure out how to record stats for ops which are a part of
// functions.
// TODO(agarwal): change Run to take vector of handles ?
ScopedStepContainer* container = ctx->StepContainer();
if (container == nullptr) {
TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
maybe_step_stats, graph_collector));
} else {
TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
maybe_step_stats, graph_collector));
TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
maybe_stats, maybe_step_stats,
graph_collector));
}
if (maybe_stats != nullptr) {
int64 nanos = Env::Default()->NowNanos();
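
Two details make the protected_tensors logic above work. First, op_inputs[i]->RefCountIsOne() checks the TensorHandle's own refcount (TensorHandle is core::RefCounted, per tensor_handle.h below): a handle with a single reference belongs exclusively to this op, so its buffer may be overwritten in place. Second, a tensorflow::Tensor is a cheap reference-counted view: copying one shares the underlying buffer and bumps the buffer's refcount, and the kernel-side forwarding path only reuses a buffer whose refcount is one. A small sketch of the second point, assuming only standard Tensor copy semantics (not code from this commit):

#include <vector>

#include "tensorflow/core/framework/tensor.h"

int main() {
  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 2}));
  {
    std::vector<tensorflow::Tensor> protected_tensors;
    protected_tensors.push_back(t);  // shares the buffer: refcount 1 -> 2
    // While the copy is alive, no kernel can forward (overwrite) t's buffer,
    // because the forwarding path requires a buffer refcount of one.
  }
  // protected_tensors is gone; the buffer refcount is back to 1 and a later
  // op could forward it again.
  return 0;
}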

@@ -57,7 +57,7 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flr,
return OutputTypesForNode(ndef, *op_def, &out->output_dtypes_);
}
Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
Status KernelAndDevice::Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats,
GraphCollector* graph_collector) {
@@ -69,15 +69,10 @@ Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
}
Status KernelAndDevice::Run(ScopedStepContainer* step_container,
std::vector<Tensor>* inputs,
const gtl::InlinedVector<TensorValue, 4>& inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats,
GraphCollector* graph_collector) {
gtl::InlinedVector<TensorValue, 4> input_vector;
for (Tensor& t : *inputs) {
input_vector.push_back(TensorValue(&t));
}
std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
for (size_t i = 0; i < out_attrs.size(); ++i) {
out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
@@ -85,7 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
}
gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
for (int i = 0; i < inputs->size(); i++) {
for (int i = 0; i < inputs.size(); i++) {
DeviceContext* device_context = nullptr;
if (device_->tensorflow_gpu_device_info() != nullptr) {
device_context = device_->tensorflow_gpu_device_info()->default_context;
@@ -96,7 +91,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
OpKernelContext::Params params;
params.device = device_;
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &input_vector;
params.inputs = &inputs;
params.op_kernel = kernel_.get();
params.resource_manager = device_->resource_manager();
params.output_attr_array = gtl::vector_as_array(&out_attrs);
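
TensorValue (defined in op_kernel.h) is a small non-owning wrapper around a Tensor* (plus an optional mutex for reference-typed tensors), so the new signature hands kernels pointers to the handles' tensors without copying data or touching any refcount; params.inputs can point directly at the caller's vector. A minimal sketch of building such an input vector (the tensors here are hypothetical):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

int main() {
  tensorflow::Tensor a(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 2}));
  tensorflow::Tensor b(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 2}));

  tensorflow::gtl::InlinedVector<tensorflow::TensorValue, 4> inputs;
  inputs.push_back(tensorflow::TensorValue(&a));  // no copy, no Ref()
  inputs.push_back(tensorflow::TensorValue(&b));

  // inputs[0].tensor == &a: kernels see the original buffers, which is what
  // lets them observe a refcount of one and forward an input into an output.
  return 0;
}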

@@ -68,11 +68,12 @@ class KernelAndDevice {
collective_executor_(std::move(collective_executor)) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
NodeExecStats* stats, StepStats* step_stats,
GraphCollector* graph_collector);
Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats, GraphCollector* graph_collector);
Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
Status Run(ScopedStepContainer* step_container,
const gtl::InlinedVector<TensorValue, 4>& inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats, GraphCollector* graph_collector);

@@ -118,9 +118,9 @@ BENCHMARK(BM_KernelAndDeviceInit);
void BM_KernelAndDeviceRun(int iters) {
tensorflow::testing::StopTiming();
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
std::vector<Tensor> inputs;
inputs.push_back(t);
inputs.push_back(t);
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&t));
inputs.push_back(TensorValue(&t));
std::vector<Tensor> outputs;
NodeDef ndef(AttrBuilder("MatMul")
.Set("T", DT_FLOAT)
@@ -134,7 +134,7 @@ void BM_KernelAndDeviceRun(int iters) {
nullptr, &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
TF_CHECK_OK(kernel.Run(inputs, &outputs, nullptr, nullptr, nullptr));
}
}
BENCHMARK(BM_KernelAndDeviceRun);

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -79,6 +78,13 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
return Status::OK();
}
Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
TF_RETURN_IF_ERROR(WaitReady());
DCHECK(IsReady());
*t = tensorflow::TensorValue(&tensor_);
return Status::OK();
}
Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
tensorflow::Device** device,
tensorflow::Device** op_device) {

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -102,6 +101,8 @@ class TensorHandle : public core::RefCounted {
Status Tensor(const tensorflow::Tensor** t);
Status TensorValue(tensorflow::TensorValue* t);
tensorflow::Device* device() const { return device_; }
tensorflow::Device* op_device() const { return op_device_; }