eager: Avoid unnecessary distinction between ops and functions.
PiperOrigin-RevId: 168022783

commit cd68ce1f57
parent c14550a383
@@ -79,6 +79,8 @@ tf_cc_test(
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
@@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   if (kernel == nullptr) {
     const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
     kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
-    if (!op->is_function()) {
-      status->status =
-          tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
-    } else {
-      // Knowledge of the implementation of InitFn (and in-turn
-      // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
-      // will be accessed, so grab on to the lock.
-      // See WARNING comment below - would be nice to rework to avoid this
-      // subtlety.
-      tensorflow::mutex_lock l(ctx->functions_mu);
-      status->status = tensorflow::KernelAndDevice::InitFn(
-          ndef, ctx->func_lib(device), kernel);
-    }
+    // Knowledge of the implementation of Init (and in-turn
+    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
+    // will be accessed, so grab on to the lock.
+    // See WARNING comment below - would be nice to rework to avoid this
+    // subtlety.
+    tensorflow::tf_shared_lock l(ctx->functions_mu);
+    status->status =
+        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
     if (!status->status.ok()) {
+      delete kernel;
       return;
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
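For readers of the hunk above: the op/function branch collapses into a single path that takes a shared (reader) lock on ctx->functions_mu before building the kernel, and the kernel is now deleted if initialization fails. Below is a minimal standalone sketch of that reader-lock-then-init shape in standard C++17; Library, Kernel, and the other names are illustrative stand-ins, not TensorFlow APIs.

// Illustrative sketch only (standard C++17), mirroring the pattern above:
// take a shared lock while reading the function library, use one Init path
// for both ops and functions, and clean up the kernel on failure.
#include <iostream>
#include <memory>
#include <shared_mutex>
#include <string>

struct Kernel {
  std::string op;  // stands in for the compiled kernel state
};

struct Library {
  std::shared_mutex mu;  // plays the role of ctx->functions_mu
  // One entry point for both primitive ops and functions.
  bool Init(const std::string& op_name, Kernel* out) const {
    out->op = op_name;
    return !op_name.empty();
  }
};

int main() {
  Library lib;
  auto kernel = std::make_unique<Kernel>();
  {
    // Shared (read) lock: code that registers new functions would take an
    // exclusive lock on the same mutex instead.
    std::shared_lock<std::shared_mutex> l(lib.mu);
    if (!lib.Init("MatMul", kernel.get())) {
      kernel.reset();  // analogous to `delete kernel` in the hunk above
      return 1;
    }
  }
  std::cout << "initialized " << kernel->op << "\n";
  return 0;
}

A shared lock suffices here because kernel construction only reads the function library; registration of new functions elsewhere still takes an exclusive lock.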
@@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
 }
 
 // static
-Status KernelAndDevice::InitFn(const NodeDef& ndef,
-                               FunctionLibraryRuntime* flib,
-                               KernelAndDevice* out) {
+Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                             KernelAndDevice* out) {
   OpKernel* k = nullptr;
   Status s = flib->CreateKernel(ndef, &k);
   out->device_ = flib->device();
@@ -150,28 +150,19 @@ class KernelAndDevice {
  public:
   // Populates 'out' with a kernel appropriate for 'ndef'.
   //
-  // Assumes that 'ndef' refers to a primitive op (as opposed to a function).
-  static Status InitOp(Device* device, const NodeDef& ndef,
-                       KernelAndDevice* out);
-
-  // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
-  // TensorFlow function in the FunctionLibraryRuntime).
-  //
   // The provided FunctionLibraryRuntime MUST outlive all calls to
   // Run() on the returned KernelAndDevice.
   //
-  // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
-  // The implementation of InitFn should work for both because
-  // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
-  // appropriate. However, for now we keep them separate because I haven't
-  // figured out thread-safety concerns around FunctionLibraryRuntime (in
-  // particular, how the underlying FunctionLibraryDefinition might be mutated
-  // by another thread as new functions are registered with it).
-  // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
-  // on to the caller (see locking in c_api.cc) for now. But I really should
-  // dig into this so that both InitOp and InitFn can be collapsed to
-  // FunctionLibraryRuntime::CreateKernel.
-  static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+  // TODO(ashankar): Figure out thread-safety concerns around
+  // FunctionLibraryRuntime (in particular, how the underlying
+  // FunctionLibraryDefinition might be mutated by another thread as new
+  // functions are registered with it). Conservatively, thread-safe usage of
+  // the FunctionLibraryRuntime is pushed on to the caller (see locking in
+  // c_api.cc).
+  static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                     KernelAndDevice* out);
+  // TODO(ashankar): Remove this
+  static Status InitOp(Device* device, const NodeDef& ndef,
                        KernelAndDevice* out);
 
   KernelAndDevice(tensorflow::Rendezvous* rendez)
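A hedged caller sketch of the consolidated API (not code from this commit), assuming a NodeDef ndef, a FunctionLibraryRuntime* flib, and a populated std::vector<Tensor> inputs are already available (the tests below arrange this through a TestEnv helper):

// Hypothetical usage sketch; ndef, flib, and inputs are assumed to exist.
tensorflow::KernelAndDevice kernel(/*rendez=*/nullptr);
tensorflow::Status s = tensorflow::KernelAndDevice::Init(ndef, flib, &kernel);
if (s.ok()) {
  std::vector<tensorflow::Tensor> outputs;
  s = kernel.Run(&inputs, &outputs);  // flib must outlive this call
}

Both primitive ops and functions go through Init; InitOp remains only until its callers migrate, per the TODO above.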
@@ -184,10 +175,10 @@ class KernelAndDevice {
 
  private:
   std::unique_ptr<OpKernel> kernel_;
-  tensorflow::Device* device_;
-  tensorflow::FunctionLibraryRuntime* flib_;
-  tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
-  tensorflow::Rendezvous* rendez_;
+  Device* device_;
+  FunctionLibraryRuntime* flib_;
+  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
+  Rendezvous* rendez_;
 };
 
 }  // namespace tensorflow
@@ -23,15 +23,36 @@ limitations under the License.
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 namespace {
 
-Device* CPUDevice() {
-  return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
-}
+class TestEnv {
+ public:
+  TestEnv() : flib_def_(OpRegistry::Global(), {}) {
+    Device* device =
+        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+    device_mgr_.reset(new DeviceMgr({device}));
+    flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
+                                              device, TF_GRAPH_DEF_VERSION,
+                                              &flib_def_, {}, nullptr);
+  }
+
+  FunctionLibraryRuntime* function_library_runtime() const {
+    return flib_runtime_.get();
+  }
+
+ private:
+  FunctionLibraryDefinition flib_def_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
+};
 
 TEST(AttrTypeMap, Lookup) {
   const AttrTypeMap* m = nullptr;
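The TestEnv helper replaces the free CPUDevice() function because Init now needs a live FunctionLibraryRuntime rather than a bare Device*. Below is a standalone sketch (illustrative names, not the TensorFlow types) of the ownership pattern it relies on: the runtime holds a non-owning pointer into state that is declared earlier and therefore destroyed later.

// Standalone sketch (not TensorFlow code) of the ownership pattern TestEnv
// uses: the runtime keeps a non-owning pointer to the library definition, so
// the definition is declared first and destroyed last.
#include <iostream>
#include <memory>
#include <string>

struct Definition {
  std::string name = "flib_def";
};

struct Runtime {
  explicit Runtime(const Definition* def) : def(def) {}
  const Definition* def;  // non-owning; must not outlive the Definition
};

class Env {
 public:
  Env() : runtime_(std::make_unique<Runtime>(&def_)) {}
  Runtime* runtime() const { return runtime_.get(); }

 private:
  Definition def_;                    // declared first => destroyed last
  std::unique_ptr<Runtime> runtime_;  // destroyed before def_
};

int main() {
  Env env;
  std::cout << env.runtime()->def->name << "\n";
  return 0;
}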
@@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
                    .Set("transpose_b", false)
                    .NumInputs(inputs.size())
                    .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice kernel(nullptr);
-  Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
+  Status s =
+      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
   ASSERT_TRUE(s.ok()) << s;
   std::vector<Tensor> outputs;
   s = kernel.Run(&inputs, &outputs);
@@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
                    .Set("transpose_b", false)
                    .NumInputs(2)
                    .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice k(nullptr);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
+    TF_CHECK_OK(
+        KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
   }
 }
 BENCHMARK(BM_KernelAndDeviceInit);
@@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
                    .Set("transpose_b", false)
                    .NumInputs(inputs.size())
                    .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice kernel(nullptr);
-  TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
+  TF_CHECK_OK(
+      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(kernel.Run(&inputs, &outputs));