eager: Avoid unnecessary distinction between ops and functions.

PiperOrigin-RevId: 168022783
This commit is contained in:
Asim Shankar 2017-09-08 11:29:17 -07:00 committed by TensorFlower Gardener
parent c14550a383
commit cd68ce1f57
5 changed files with 60 additions and 48 deletions

View File

@ -79,6 +79,8 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],

View File

@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
if (!op->is_function()) {
status->status =
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
} else {
// Knowledge of the implementation of InitFn (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment below - would be nice to rework to avoid this
// subtlety.
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = tensorflow::KernelAndDevice::InitFn(
ndef, ctx->func_lib(device), kernel);
}
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment below - would be nice to rework to avoid this
// subtlety.
tensorflow::tf_shared_lock l(ctx->functions_mu);
status->status =
tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
}
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);

View File

@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
}
// static
Status KernelAndDevice::InitFn(const NodeDef& ndef,
FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();

View File

@ -150,28 +150,19 @@ class KernelAndDevice {
public:
// Populates 'out' with a kernel appropriate for 'ndef'.
//
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
// TensorFlow function in the FunctionLibraryRuntime).
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
//
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
// The implementation of InitFn should work for both because
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
// appropriate. However, for now we keep them separate because I haven't
// figured out thread-safety concerns around FunctionLibraryRuntime (in
// particular, how the underlying FunctionLibraryDefinition might be mutated
// by another thread as new functions are registered with it).
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
// on to the caller (see locking in c_api.cc) for now. But I really should
// dig into this so that both InitOp and InitFn can be collapsed to
// FunctionLibraryRuntime::CreateKernel.
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
// TODO(ashankar): Figure out thread-safety concerns around
// FunctionLibraryRuntime (in particular, how the underlying
// FunctionLibraryDefinition might be mutated by another thread as new
// functions are registered with it). Conservatively, thread-safe usage of
// the FunctionLibraryRuntime is pushed on to the caller (see locking in
// c_api.cc).
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out);
// TODO(ashankar): Remove this
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
KernelAndDevice(tensorflow::Rendezvous* rendez)
@ -184,10 +175,10 @@ class KernelAndDevice {
private:
std::unique_ptr<OpKernel> kernel_;
tensorflow::Device* device_;
tensorflow::FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
tensorflow::Rendezvous* rendez_;
Device* device_;
FunctionLibraryRuntime* flib_;
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
Rendezvous* rendez_;
};
} // namespace tensorflow

View File

@ -23,15 +23,36 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
// Creates a single CPU device for tests. The caller owns the returned
// Device (callers below wrap it in std::unique_ptr<Device>).
Device* CPUDevice() {
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
}
// Test fixture that owns a CPU device, a DeviceMgr, and a
// FunctionLibraryRuntime, so KernelAndDevice::Init(ndef, flib, out) can be
// exercised in tests and benchmarks without a full runtime.
class TestEnv {
public:
// Builds one CPU device, hands ownership to a DeviceMgr, and constructs a
// FunctionLibraryRuntime over an empty FunctionLibraryDefinition (seeded
// from the global op registry).
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
Device* device =
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
&flib_def_, {}, nullptr);
}
// Non-owning accessor; the returned runtime stays valid for the lifetime
// of this TestEnv (it is owned by flib_runtime_ below).
FunctionLibraryRuntime* function_library_runtime() const {
return flib_runtime_.get();
}
private:
// NOTE: declaration order matters — flib_def_ and device_mgr_ must outlive
// flib_runtime_, which holds raw pointers into both.
FunctionLibraryDefinition flib_def_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
};
TEST(AttrTypeMap, Lookup) {
const AttrTypeMap* m = nullptr;
@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice kernel(nullptr);
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
Status s =
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
s = kernel.Run(&inputs, &outputs);
@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
.Set("transpose_b", false)
.NumInputs(2)
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice k(nullptr);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
}
}
BENCHMARK(BM_KernelAndDeviceInit);
@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice kernel(nullptr);
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs));