From cd68ce1f57ea844f4f6f4b394a17b21eb48b728a Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Fri, 8 Sep 2017 11:29:17 -0700 Subject: [PATCH] eager: Avoid unnecessary distinction between ops and functions. PiperOrigin-RevId: 168022783 --- tensorflow/c/eager/BUILD | 2 ++ tensorflow/c/eager/c_api.cc | 22 +++++++--------- tensorflow/c/eager/runtime.cc | 5 ++-- tensorflow/c/eager/runtime.h | 37 ++++++++++---------------- tensorflow/c/eager/runtime_test.cc | 42 +++++++++++++++++++++++------- 5 files changed, 60 insertions(+), 48 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 85c4e4fd93c..52945d32391 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -79,6 +79,8 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 473e6339f36..56f7303f70f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); - if (!op->is_function()) { - status->status = - tensorflow::KernelAndDevice::InitOp(device, ndef, kernel); - } else { - // Knowledge of the implementation of InitFn (and in-turn - // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def - // will be accessed, so grab on to the lock. - // See WARNING comment below - would be nice to rework to avoid this - // subtlety. 
- tensorflow::mutex_lock l(ctx->functions_mu); - status->status = tensorflow::KernelAndDevice::InitFn( - ndef, ctx->func_lib(device), kernel); - } + // Knowledge of the implementation of Init (and in-turn + // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def + // will be accessed, so grab on to the lock. + // See WARNING comment below - would be nice to rework to avoid this + // subtlety. + tensorflow::tf_shared_lock l(ctx->functions_mu); + status->status = + tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); if (!status->status.ok()) { + delete kernel; return; } tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index b6d53872c97..3b39903e09a 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, } // static -Status KernelAndDevice::InitFn(const NodeDef& ndef, - FunctionLibraryRuntime* flib, - KernelAndDevice* out) { +Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + KernelAndDevice* out) { OpKernel* k = nullptr; Status s = flib->CreateKernel(ndef, &k); out->device_ = flib->device(); diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index bb098f74013..13b49e5e8cb 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -150,28 +150,19 @@ class KernelAndDevice { public: // Populates 'out' with a kernel appropriate for 'ndef'. // - // Assumes that 'ndef' refers to a primitive op (as opposed to a function). - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - - // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a - // TensorFlow function in the FunctionLibraryRuntime). - // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. 
// - // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn. - // The implementation of InitFn should work for both because - // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if - // appropriate. However, for now we keep them separate because I haven't - // figured out thread-safety concerns around FunctionLibraryRuntime (in - // particular, how the underlying FunctionLibraryDefinition might be mutated - // by another thread as new functions are registered with it). - // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed - // on to the caller (see locking in c_api.cc) for now. But I really should - // dig into this so that both InitOp and InitFn can be collapsed to - // FunctionLibraryRuntime::CreateKernel. - static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib, + // TODO(ashankar): Figure out thread-safety concerns around + // FunctionLibraryRuntime (in particular, how the underlying + // FunctionLibraryDefinition might be mutated by another thread as new + // functions are registered with it). Conservatively, thread-safe usage of + // the FunctionLibraryRuntime is pushed on to the caller (see locking in + // c_api.cc). 
+ static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + KernelAndDevice* out); + // TODO(ashankar): Remove this + static Status InitOp(Device* device, const NodeDef& ndef, KernelAndDevice* out); KernelAndDevice(tensorflow::Rendezvous* rendez) @@ -184,10 +175,10 @@ class KernelAndDevice { private: std::unique_ptr kernel_; - tensorflow::Device* device_; - tensorflow::FunctionLibraryRuntime* flib_; - tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; - tensorflow::Rendezvous* rendez_; + Device* device_; + FunctionLibraryRuntime* flib_; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; + Rendezvous* rendez_; }; } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index f9bfce38580..3236c6be0ec 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -23,15 +23,36 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { -Device* CPUDevice() { - return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); -} +class TestEnv { + public: + TestEnv() : flib_def_(OpRegistry::Global(), {}) { + Device* device = + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); + device_mgr_.reset(new DeviceMgr({device})); + flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), + device, TF_GRAPH_DEF_VERSION, + &flib_def_, {}, nullptr); + } + + FunctionLibraryRuntime* function_library_runtime() const { + return flib_runtime_.get(); + } + + private: + 
 FunctionLibraryDefinition flib_def_; + std::unique_ptr<DeviceMgr> device_mgr_; + std::unique_ptr<FunctionLibraryRuntime> flib_runtime_; +}; TEST(AttrTypeMap, Lookup) { const AttrTypeMap* m = nullptr; @@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) { .Set("transpose_b", false) .NumInputs(inputs.size()) .BuildNodeDef()); - std::unique_ptr<Device> device(CPUDevice()); + TestEnv env; KernelAndDevice kernel(nullptr); - Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel); + Status s = + KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); ASSERT_TRUE(s.ok()) << s; std::vector<Tensor> outputs; s = kernel.Run(&inputs, &outputs); @@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) { .Set("transpose_b", false) .NumInputs(2) .BuildNodeDef()); - std::unique_ptr<Device> device(CPUDevice()); + TestEnv env; KernelAndDevice k(nullptr); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k)); + TF_CHECK_OK( + KernelAndDevice::Init(ndef, env.function_library_runtime(), &k)); } } BENCHMARK(BM_KernelAndDeviceInit); @@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) { .Set("transpose_b", false) .NumInputs(inputs.size()) .BuildNodeDef()); - std::unique_ptr<Device> device(CPUDevice()); + TestEnv env; KernelAndDevice kernel(nullptr); - TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel)); + TF_CHECK_OK( + KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_CHECK_OK(kernel.Run(&inputs, &outputs));