Allowing tf.debugging.set_log_device_placement to control logging of device placement of tf.functions.

PiperOrigin-RevId: 314746176
Change-Id: I2516ec91e99f5d2b21db11b1cbfa2fdd1d5ef796
parent 80a93674ea
commit fe33f393b8
Changed files under:
tensorflow/core/common_runtime/eager
tensorflow/core/distributed_runtime/eager
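For context, a minimal sketch in Python of the user-visible behavior this change enables (the snippet is illustrative and assumed, not part of the commit; matmul_fn is a hypothetical name): once placement logging is enabled, device placements are also logged for ops executed inside a tf.function, not only for eagerly executed ops.

import tensorflow as tf

# Enable device placement logging; this is the API named in the commit title.
tf.debugging.set_log_device_placement(True)

@tf.function
def matmul_fn(a, b):  # hypothetical example function
  # With this change, the matmul's placement is logged even though it
  # runs inside a tf.function rather than eagerly.
  return tf.matmul(a, b)

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[1.0, 1.0], [0.0, 1.0]])
print(matmul_fn(a, b))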
@@ -537,7 +537,8 @@ Status GetOrCreateKernelAndDevice(
         ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
   }
 
-  TF_RETURN_IF_ERROR(kernel->Init(ndef, graph_collector));
+  TF_RETURN_IF_ERROR(
+      kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
 
   if (op->is_function()) {
     ctx.AddKernelToCache(cache_key, kernel.get());
@@ -99,7 +99,7 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() {
   }
 }
 
-Status KernelAndDeviceOp::Init(const NodeDef& ndef,
+Status KernelAndDeviceOp::Init(const Context& ctx, const NodeDef& ndef,
                                GraphCollector* graph_collector) {
   OpKernel* k = nullptr;
   if (flr_ == nullptr) {
@@ -131,7 +131,8 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef,
   return Status::OK();
 }
 
-Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
+Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
+                                            const NodeDef& ndef,
                                             GraphCollector* graph_collector) {
   const OpDef* op_def = nullptr;
   const FunctionDef* function_def;
@@ -209,14 +210,16 @@ Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
 
+  options.config_proto.set_log_device_placement(ctx.log_device_placement);
+
   TF_RETURN_IF_ERROR(
       pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
   return pflr_->IsCrossProcess(handle_, &is_cross_process_);
 }
 
-Status KernelAndDeviceFunc::Init(const NodeDef& ndef,
+Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
                                  GraphCollector* graph_collector) {
-  TF_RETURN_IF_ERROR(InstantiateFunc(ndef, graph_collector));
+  TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
   return pflr_->GetOutputDevices(handle_, &output_devices_);
 }
 
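The added set_log_device_placement line above is the key plumbing step: the flag carried in Context is copied into the ConfigProto used when the function is instantiated, which is what lets tf.debugging.set_log_device_placement reach ops running inside tf.functions.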
@@ -93,11 +93,16 @@ class EagerKernelArgs : public FunctionArgsInterface {
 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
 class KernelAndDevice : public core::RefCounted {
  public:
+  struct Context {
+    bool log_device_placement = false;
+  };
+
   // Populates this with a kernel appropriate for 'ndef'.
   //
   // The provided FunctionLibraryRuntime MUST outlive all calls to
   // Run() on the returned KernelAndDevice.
-  virtual Status Init(const NodeDef& ndef, GraphCollector* graph_collector) = 0;
+  virtual Status Init(const Context& ctx, const NodeDef& ndef,
+                      GraphCollector* graph_collector) = 0;
 
   // Non-multi-device functions are run using regular CallOp and look like
   // primitive operations from KernelAndDevice perspective.
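The braced arguments at the updated call sites ({} in the tests and benchmarks, {ctx.LogDevicePlacement()} in the eager runtime) aggregate-initialize this new Context struct; an empty {} leaves log_device_placement at its default of false, preserving the old behavior.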
@@ -194,7 +199,8 @@ class KernelAndDeviceOp final : public KernelAndDevice {
 
   ~KernelAndDeviceOp() override {}
 
-  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;
+  Status Init(const Context& ctx, const NodeDef& ndef,
+              GraphCollector* graph_collector) override;
 
   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
              std::vector<Tensor>* outputs,
@@ -282,9 +288,11 @@ class KernelAndDeviceFunc : public KernelAndDevice {
 
   bool IsFunction() override { return true; };
 
-  Status InstantiateFunc(const NodeDef& ndef, GraphCollector* graph_collector);
+  Status InstantiateFunc(const Context& ctx, const NodeDef& ndef,
+                         GraphCollector* graph_collector);
 
-  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;
+  Status Init(const Context& ctx, const NodeDef& ndef,
+              GraphCollector* graph_collector) override;
 
   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
              std::vector<Tensor>* outputs,
@@ -122,7 +122,7 @@ void BM_KernelAndDeviceInit(int iters) {
                       nullptr, env.cpu_device());
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(k.Init(ndef, nullptr));
+    TF_CHECK_OK(k.Init({}, ndef, nullptr));
   }
 }
 BENCHMARK(BM_KernelAndDeviceInit);
@@ -143,7 +143,7 @@ void BM_KernelAndDeviceRun(int iters) {
   TestEnv env;
   KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr,
                       nullptr, env.cpu_device());
-  TF_CHECK_OK(k.Init(ndef, nullptr));
+  TF_CHECK_OK(k.Init({}, ndef, nullptr));
   const EagerKernelArgs args(std::move(inputs));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
@@ -920,7 +920,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
 
   // Instantiate MatMulFunction on remote_device.
   const NodeDef node_def = MatMulFunctionNodeDef();
-  TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr));
+  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
 
   // Run MatMulFunction on remote_device.
   gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
@@ -967,7 +967,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
 
   // Instantiate MatMulFunction on remote_device.
   const NodeDef node_def = MatMulFunctionNodeDef();
-  TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr));
+  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
 
   // Run MatMulFunction on remote_device.
   gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
@@ -58,7 +58,8 @@ Status CreateUncachedKernelAndDeviceOp(
       ctx.HostCPU()));
 
   const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
-  return kernel->get()->Init(ndef, /*graph_collector=*/nullptr);
+  return kernel->get()->Init({ctx.LogDevicePlacement()}, ndef,
+                             /*graph_collector=*/nullptr);
 }
 
 // This gets a unique wire ID. We add a random identifier so that if the