eager: Avoid unnecessary distinction between ops and functions.
PiperOrigin-RevId: 168022783
This commit is contained in:
parent
c14550a383
commit
cd68ce1f57
@ -79,6 +79,8 @@ tf_cc_test(
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
|
@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
if (kernel == nullptr) {
|
||||
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
|
||||
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
|
||||
if (!op->is_function()) {
|
||||
status->status =
|
||||
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
|
||||
} else {
|
||||
// Knowledge of the implementation of InitFn (and in-turn
|
||||
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
|
||||
// will be accessed, so grab on to the lock.
|
||||
// See WARNING comment below - would be nice to rework to avoid this
|
||||
// subtlety.
|
||||
tensorflow::mutex_lock l(ctx->functions_mu);
|
||||
status->status = tensorflow::KernelAndDevice::InitFn(
|
||||
ndef, ctx->func_lib(device), kernel);
|
||||
}
|
||||
// Knowledge of the implementation of Init (and in-turn
|
||||
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
|
||||
// will be accessed, so grab on to the lock.
|
||||
// See WARNING comment below - would be nice to rework to avoid this
|
||||
// subtlety.
|
||||
tensorflow::tf_shared_lock l(ctx->functions_mu);
|
||||
status->status =
|
||||
tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
|
||||
if (!status->status.ok()) {
|
||||
delete kernel;
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
|
||||
|
@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
|
||||
}
|
||||
|
||||
// static
|
||||
Status KernelAndDevice::InitFn(const NodeDef& ndef,
|
||||
FunctionLibraryRuntime* flib,
|
||||
KernelAndDevice* out) {
|
||||
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||
KernelAndDevice* out) {
|
||||
OpKernel* k = nullptr;
|
||||
Status s = flib->CreateKernel(ndef, &k);
|
||||
out->device_ = flib->device();
|
||||
|
@ -150,28 +150,19 @@ class KernelAndDevice {
|
||||
public:
|
||||
// Populates 'out' with a kernel appropriate for 'ndef'.
|
||||
//
|
||||
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
|
||||
static Status InitOp(Device* device, const NodeDef& ndef,
|
||||
KernelAndDevice* out);
|
||||
|
||||
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
|
||||
// TensorFlow function in the FunctionLibraryRuntime).
|
||||
//
|
||||
// The provided FunctionLibraryRuntime MUST outlive all calls to
|
||||
// Run() on the returned KernelAndDevice.
|
||||
//
|
||||
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
|
||||
// The implementation of InitFn should work for both because
|
||||
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
|
||||
// appropriate. However, for now we keep them separate because I haven't
|
||||
// figured out thread-safety concerns around FunctionLibraryRuntime (in
|
||||
// particular, how the underlying FunctionLibraryDefinition might be mutated
|
||||
// by another thread as new functions are registered with it).
|
||||
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
|
||||
// on to the caller (see locking in c_api.cc) for now. But I really should
|
||||
// dig into this so that both InitOp and InitFn can be collapsed to
|
||||
// FunctionLibraryRuntime::CreateKernel.
|
||||
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||
// TODO(ashankar): Figure out thread-safety concerns around
|
||||
// FunctionLibraryRuntime (in particular, how the underlying
|
||||
// FunctionLibraryDefinition might be mutated by another thread as new
|
||||
// functions are registered with it). Conservatively, thread-safe usage of
|
||||
// the FunctionLibraryRuntime is pushed on to the caller (see locking in
|
||||
// c_api.cc).
|
||||
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||
KernelAndDevice* out);
|
||||
// TODO(ashankar): Remove this
|
||||
static Status InitOp(Device* device, const NodeDef& ndef,
|
||||
KernelAndDevice* out);
|
||||
|
||||
KernelAndDevice(tensorflow::Rendezvous* rendez)
|
||||
@ -184,10 +175,10 @@ class KernelAndDevice {
|
||||
|
||||
private:
|
||||
std::unique_ptr<OpKernel> kernel_;
|
||||
tensorflow::Device* device_;
|
||||
tensorflow::FunctionLibraryRuntime* flib_;
|
||||
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
|
||||
tensorflow::Rendezvous* rendez_;
|
||||
Device* device_;
|
||||
FunctionLibraryRuntime* flib_;
|
||||
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
|
||||
Rendezvous* rendez_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -23,15 +23,36 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Device* CPUDevice() {
|
||||
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
|
||||
}
|
||||
class TestEnv {
|
||||
public:
|
||||
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
|
||||
Device* device =
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
|
||||
device_mgr_.reset(new DeviceMgr({device}));
|
||||
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
|
||||
device, TF_GRAPH_DEF_VERSION,
|
||||
&flib_def_, {}, nullptr);
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* function_library_runtime() const {
|
||||
return flib_runtime_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
FunctionLibraryDefinition flib_def_;
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
|
||||
};
|
||||
|
||||
TEST(AttrTypeMap, Lookup) {
|
||||
const AttrTypeMap* m = nullptr;
|
||||
@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
|
||||
.Set("transpose_b", false)
|
||||
.NumInputs(inputs.size())
|
||||
.BuildNodeDef());
|
||||
std::unique_ptr<Device> device(CPUDevice());
|
||||
TestEnv env;
|
||||
KernelAndDevice kernel(nullptr);
|
||||
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
|
||||
Status s =
|
||||
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
|
||||
ASSERT_TRUE(s.ok()) << s;
|
||||
std::vector<Tensor> outputs;
|
||||
s = kernel.Run(&inputs, &outputs);
|
||||
@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
|
||||
.Set("transpose_b", false)
|
||||
.NumInputs(2)
|
||||
.BuildNodeDef());
|
||||
std::unique_ptr<Device> device(CPUDevice());
|
||||
TestEnv env;
|
||||
KernelAndDevice k(nullptr);
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
|
||||
TF_CHECK_OK(
|
||||
KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_KernelAndDeviceInit);
|
||||
@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
|
||||
.Set("transpose_b", false)
|
||||
.NumInputs(inputs.size())
|
||||
.BuildNodeDef());
|
||||
std::unique_ptr<Device> device(CPUDevice());
|
||||
TestEnv env;
|
||||
KernelAndDevice kernel(nullptr);
|
||||
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
|
||||
TF_CHECK_OK(
|
||||
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TF_CHECK_OK(kernel.Run(&inputs, &outputs));
|
||||
|
Loading…
Reference in New Issue
Block a user