Allow tf.debugging.set_log_device_placement to control logging of device placement for tf.functions.

PiperOrigin-RevId: 314746176
Change-Id: I2516ec91e99f5d2b21db11b1cbfa2fdd1d5ef796
This commit is contained in:
Jiri Simsa 2020-06-04 09:43:27 -07:00 committed by TensorFlower Gardener
parent 80a93674ea
commit fe33f393b8
6 changed files with 27 additions and 14 deletions

View File

@ -537,7 +537,8 @@ Status GetOrCreateKernelAndDevice(
ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
}
TF_RETURN_IF_ERROR(kernel->Init(ndef, graph_collector));
TF_RETURN_IF_ERROR(
kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
if (op->is_function()) {
ctx.AddKernelToCache(cache_key, kernel.get());

View File

@ -99,7 +99,7 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() {
}
}
Status KernelAndDeviceOp::Init(const NodeDef& ndef,
Status KernelAndDeviceOp::Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) {
OpKernel* k = nullptr;
if (flr_ == nullptr) {
@ -131,7 +131,8 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef,
return Status::OK();
}
Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
const NodeDef& ndef,
GraphCollector* graph_collector) {
const OpDef* op_def = nullptr;
const FunctionDef* function_def;
@ -209,14 +210,16 @@ Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
->mutable_optimizer_options()
->set_do_function_inlining(true);
options.config_proto.set_log_device_placement(ctx.log_device_placement);
TF_RETURN_IF_ERROR(
pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
return pflr_->IsCrossProcess(handle_, &is_cross_process_);
}
Status KernelAndDeviceFunc::Init(const NodeDef& ndef,
Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) {
TF_RETURN_IF_ERROR(InstantiateFunc(ndef, graph_collector));
TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
return pflr_->GetOutputDevices(handle_, &output_devices_);
}

View File

@ -93,11 +93,16 @@ class EagerKernelArgs : public FunctionArgsInterface {
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice : public core::RefCounted {
public:
struct Context {
bool log_device_placement = false;
};
// Populates this with a kernel appropriate for 'ndef'.
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
virtual Status Init(const NodeDef& ndef, GraphCollector* graph_collector) = 0;
virtual Status Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) = 0;
// Non-multi-device functions are run using regular CallOp and look like
// primitive operations from KernelAndDevice perspective.
@ -194,7 +199,8 @@ class KernelAndDeviceOp final : public KernelAndDevice {
~KernelAndDeviceOp() override {}
Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;
Status Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) override;
Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs,
@ -282,9 +288,11 @@ class KernelAndDeviceFunc : public KernelAndDevice {
bool IsFunction() override { return true; };
Status InstantiateFunc(const NodeDef& ndef, GraphCollector* graph_collector);
Status InstantiateFunc(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector);
Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;
Status Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) override;
Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs,

View File

@ -122,7 +122,7 @@ void BM_KernelAndDeviceInit(int iters) {
nullptr, env.cpu_device());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(k.Init(ndef, nullptr));
TF_CHECK_OK(k.Init({}, ndef, nullptr));
}
}
BENCHMARK(BM_KernelAndDeviceInit);
@ -143,7 +143,7 @@ void BM_KernelAndDeviceRun(int iters) {
TestEnv env;
KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr,
nullptr, env.cpu_device());
TF_CHECK_OK(k.Init(ndef, nullptr));
TF_CHECK_OK(k.Init({}, ndef, nullptr));
const EagerKernelArgs args(std::move(inputs));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {

View File

@ -920,7 +920,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
// Instantiate MatMulFunction on remote_device.
const NodeDef node_def = MatMulFunctionNodeDef();
TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr));
TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
// Run MatMulFunction on remote_device.
gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
@ -967,7 +967,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
// Instantiate MatMulFunction on remote_device.
const NodeDef node_def = MatMulFunctionNodeDef();
TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr));
TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
// Run MatMulFunction on remote_device.
gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};

View File

@ -58,7 +58,8 @@ Status CreateUncachedKernelAndDeviceOp(
ctx.HostCPU()));
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
return kernel->get()->Init(ndef, /*graph_collector=*/nullptr);
return kernel->get()->Init({ctx.LogDevicePlacement()}, ndef,
/*graph_collector=*/nullptr);
}
// This gets a unique wire ID. We add a random identifier so that if the