eager: Avoid unnecessary distinction between ops and functions.

PiperOrigin-RevId: 168022783
Authored by: Asim Shankar on 2017-09-08 11:29:17 -07:00
Committed by: TensorFlower Gardener
parent c14550a383
commit cd68ce1f57
5 changed files with 60 additions and 48 deletions

View File

@ -79,6 +79,8 @@ tf_cc_test(
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/cc:scope", "//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
], ],

View File

@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (kernel == nullptr) { if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
if (!op->is_function()) { // Knowledge of the implementation of Init (and in-turn
status->status = // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel); // will be accessed, so grab on to the lock.
} else { // See WARNING comment below - would be nice to rework to avoid this
// Knowledge of the implementation of InitFn (and in-turn // subtlety.
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def tensorflow::tf_shared_lock l(ctx->functions_mu);
// will be accessed, so grab on to the lock. status->status =
// See WARNING comment below - would be nice to rework to avoid this tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
// subtlety.
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = tensorflow::KernelAndDevice::InitFn(
ndef, ctx->func_lib(device), kernel);
}
if (!status->status.ok()) { if (!status->status.ok()) {
delete kernel;
return; return;
} }
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);

View File

@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
} }
// static // static
Status KernelAndDevice::InitFn(const NodeDef& ndef, Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
FunctionLibraryRuntime* flib, KernelAndDevice* out) {
KernelAndDevice* out) {
OpKernel* k = nullptr; OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k); Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device(); out->device_ = flib->device();

View File

@ -150,28 +150,19 @@ class KernelAndDevice {
public: public:
// Populates 'out' with a kernel appropriate for 'ndef'. // Populates 'out' with a kernel appropriate for 'ndef'.
// //
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
// TensorFlow function in the FunctionLibraryRuntime).
//
// The provided FunctionLibraryRuntime MUST outlive all calls to // The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice. // Run() on the returned KernelAndDevice.
// //
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn. // TODO(ashankar): Figure out thread-safety concerns around
// The implementation of InitFn should work for both because // FunctionLibraryRuntime (in particular, how the underlying
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if // FunctionLibraryDefinition might be mutated by another thread as new
// appropriate. However, for now we keep them separate because I haven't // functions are registered with it). Conservatively, thread-safe usage of
// figured out thread-safety concerns around FunctionLibraryRuntime (in // the FunctionLibraryRuntime is pushed on to the caller (see locking in
// particular, how the underlying FunctionLibraryDefinition might be mutated // c_api.cc).
// by another thread as new functions are registered with it). static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed KernelAndDevice* out);
// on to the caller (see locking in c_api.cc) for now. But I really should // TODO(ashankar): Remove this
// dig into this so that both InitOp and InitFn can be collapsed to static Status InitOp(Device* device, const NodeDef& ndef,
// FunctionLibraryRuntime::CreateKernel.
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out); KernelAndDevice* out);
KernelAndDevice(tensorflow::Rendezvous* rendez) KernelAndDevice(tensorflow::Rendezvous* rendez)
@ -184,10 +175,10 @@ class KernelAndDevice {
private: private:
std::unique_ptr<OpKernel> kernel_; std::unique_ptr<OpKernel> kernel_;
tensorflow::Device* device_; Device* device_;
tensorflow::FunctionLibraryRuntime* flib_; FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
tensorflow::Rendezvous* rendez_; Rendezvous* rendez_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -23,15 +23,36 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
Device* CPUDevice() { class TestEnv {
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); public:
} TestEnv() : flib_def_(OpRegistry::Global(), {}) {
Device* device =
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
&flib_def_, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {
return flib_runtime_.get();
}
private:
FunctionLibraryDefinition flib_def_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
};
TEST(AttrTypeMap, Lookup) { TEST(AttrTypeMap, Lookup) {
const AttrTypeMap* m = nullptr; const AttrTypeMap* m = nullptr;
@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
.Set("transpose_b", false) .Set("transpose_b", false)
.NumInputs(inputs.size()) .NumInputs(inputs.size())
.BuildNodeDef()); .BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice()); TestEnv env;
KernelAndDevice kernel(nullptr); KernelAndDevice kernel(nullptr);
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel); Status s =
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
ASSERT_TRUE(s.ok()) << s; ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs; std::vector<Tensor> outputs;
s = kernel.Run(&inputs, &outputs); s = kernel.Run(&inputs, &outputs);
@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
.Set("transpose_b", false) .Set("transpose_b", false)
.NumInputs(2) .NumInputs(2)
.BuildNodeDef()); .BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice()); TestEnv env;
KernelAndDevice k(nullptr); KernelAndDevice k(nullptr);
tensorflow::testing::StartTiming(); tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) { for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k)); TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
} }
} }
BENCHMARK(BM_KernelAndDeviceInit); BENCHMARK(BM_KernelAndDeviceInit);
@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
.Set("transpose_b", false) .Set("transpose_b", false)
.NumInputs(inputs.size()) .NumInputs(inputs.size())
.BuildNodeDef()); .BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice()); TestEnv env;
KernelAndDevice kernel(nullptr); KernelAndDevice kernel(nullptr);
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel)); TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
tensorflow::testing::StartTiming(); tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) { for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs)); TF_CHECK_OK(kernel.Run(&inputs, &outputs));