diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index a9ffae4fb8d..88360ac9dc4 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -631,6 +631,7 @@ cc_library( ], copts = tf_copts(), deps = [ + ":composite_device", ":device", ":device_mgr", ":device_set", @@ -651,6 +652,7 @@ cc_library( ":process_util", ":rendezvous_mgr", ":rendezvous_util", + ":replicate_per_replica_nodes", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -658,6 +660,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index 8084ae98abc..b7bebd4ba11 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -52,9 +52,14 @@ Status ExecuteNodeArgs::Init( #if !defined(IS_MOBILE_PLATFORM) if (has_remote_inputs_) { serialize_remote_handle_ = - [ctx, &op_inputs](const int i, + [ctx, &op_inputs](const FunctionArgIndex& index, eager::RemoteTensorHandle* handle) -> Status { - VariantDevice variant_device = op_inputs[i]->device(); + if (index.sub_index >= 0) { + return errors::InvalidArgument("Got unexpected sub_index ", + index.sub_index, " for argument ", + index.index); + } + VariantDevice variant_device = op_inputs[index.index]->device(); if (VariantDeviceIsCustom(variant_device)) { return errors::Internal( "Custom devices and remote execution are currently not supported " @@ -62,7 +67,7 @@ Status ExecuteNodeArgs::Init( } Device* device = absl::get(variant_device); return ctx->RemoteMgr()->SerializeRemoteTensorHandle( - op_inputs[i], handle, device, device->name()); + op_inputs[index.index], handle, device, device->name()); }; } #endif // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index be6e4009896..d416f58bbcd 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -54,10 +54,12 @@ class ExecuteNodeArgs : public EagerKernelArgs { const absl::InlinedVector& op_inputs, const core::RefCountPtr& kernel); - bool HasRemoteInputs() const override { return has_remote_inputs_; }; + bool HasRemoteOrPackedInputs() const override { + return has_remote_inputs_ || has_packed_inputs_; + }; #if !defined(IS_MOBILE_PLATFORM) - Status GetRemoteArg(const int index, + Status GetRemoteArg(const FunctionArgIndex& index, eager::RemoteTensorHandle* val) const override { return serialize_remote_handle_(index, val); } @@ -65,8 +67,9 @@ class ExecuteNodeArgs : public EagerKernelArgs { private: bool has_remote_inputs_ = false; + bool has_packed_inputs_ = false; #if !defined(IS_MOBILE_PLATFORM) - std::function + std::function serialize_remote_handle_; #endif // IS_MOBILE_PLATFORM }; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index c9ff9e506b8..98d71959e2d 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/profiler/lib/annotated_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -49,13 +50,18 @@ limitations under the License. namespace tensorflow { -Status EagerKernelArgs::GetLocalArg(const int index, Tensor* val) const { - Tensor* arg = tensor_args_.at(index).tensor; +Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const { + if (index.sub_index >= 0) { + return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index, + " for argument ", index.index); + } + Tensor* arg = tensor_args_.at(index.index).tensor; if (arg) { *val = *arg; return Status::OK(); } else { - return errors::NotFound("Argument ", index, " has no local tensor."); + return errors::NotFound("Argument ", index.index, " has no local tensor."); } } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 0597dc0aa2e..a740b898262 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -66,16 +66,16 @@ class EagerKernelArgs : public FunctionArgsInterface { ~EagerKernelArgs() override{}; - bool HasRemoteInputs() const override { return false; }; + bool HasRemoteOrPackedInputs() const override { return false; }; TensorValue* MutableInput(int i) { return &tensor_args_[i]; } - Status GetLocalArg(const int index, Tensor* val) const override; + Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override; std::vector GetLocalTensors() const override; - const gtl::InlinedVector* GetTensorValues() const override { + const gtl::InlinedVector* GetTensorValues() const { return &tensor_args_; - }; + } protected: gtl::InlinedVector tensor_args_; diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc index 6cb56080a27..6fb7526c512 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.cc +++ b/tensorflow/core/common_runtime/partitioning_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/partitioning_utils.h" #include +#include #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" @@ -73,11 +74,11 @@ Status PartitionFunctionGraph( } Status UpdateArgAndRetvalMetadata( - Graph* subgraph, const string& device_type, std::vector* arg_indices, - std::vector* ret_indices, + Graph* subgraph, const string& device_type, + std::vector* arg_indices, std::vector* ret_indices, std::vector* arg_alloc_attrs, std::vector* ret_alloc_attrs) { - std::vector> arg_nodes; + std::vector> arg_nodes; std::vector> ret_nodes; const AttrValue* attr_value; @@ -87,7 +88,11 @@ Status UpdateArgAndRetvalMetadata( if (node->IsArg()) { TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value)); int index = static_cast(attr_value->i()); - arg_nodes.emplace_back(node, index); + int sub_index = -1; + if (node->attrs().Find("sub_index", &attr_value).ok()) { + sub_index = static_cast(attr_value->i()); + } + arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index)); } else if (node->IsRetval()) { TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value)); int index = static_cast(attr_value->i()); @@ -99,11 +104,16 @@ Status UpdateArgAndRetvalMetadata( // // In particular, this enables calling a single-partition function with // the same signature as the original unpartitioned function. - auto comparator = [](std::pair a, std::pair b) { + auto arg_comparator = [](std::pair a, + std::pair b) { + return std::tie(a.second.index, a.second.sub_index) < + std::tie(b.second.index, b.second.sub_index); + }; + std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator); + auto ret_comparator = [](std::pair a, std::pair b) { return a.second < b.second; }; - std::sort(arg_nodes.begin(), arg_nodes.end(), comparator); - std::sort(ret_nodes.begin(), ret_nodes.end(), comparator); + std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator); arg_indices->reserve(arg_nodes.size()); for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second); @@ -144,16 +154,6 @@ Status UpdateArgAndRetvalMetadata( return Status::OK(); } -std::vector GetArgsForIndices(const std::vector& indices, - gtl::ArraySlice arguments) { - std::vector args; - args.reserve(indices.size()); - for (int i : indices) { - args.push_back(arguments[i]); - } - return args; -} - string FunctionNameGenerator::GetName() { while (true) { const string candidate = strings::StrCat(name_, "_", counter_++); diff --git a/tensorflow/core/common_runtime/partitioning_utils.h b/tensorflow/core/common_runtime/partitioning_utils.h index 7d2a2c2d2eb..1eb17423de0 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.h +++ b/tensorflow/core/common_runtime/partitioning_utils.h @@ -58,15 +58,11 @@ Status PartitionFunctionGraph( // (3) records which `Arg` and `Retval` nodes live in host memory in // `*_alloc_attrs`. Status UpdateArgAndRetvalMetadata( - Graph* subgraph, const string& device_type, std::vector* arg_indices, - std::vector* ret_indices, + Graph* subgraph, const string& device_type, + std::vector* arg_indices, std::vector* ret_indices, std::vector* arg_alloc_attrs, std::vector* ret_alloc_attrs); -// Extracts tensors at `indices` from `arguments`. -std::vector GetArgsForIndices(const std::vector& indices, - gtl::ArraySlice arguments); - // Utility for generating function names not present in `flib_def`, using // given `name` as the base for the name. 
class FunctionNameGenerator { diff --git a/tensorflow/core/common_runtime/partitioning_utils_test.cc b/tensorflow/core/common_runtime/partitioning_utils_test.cc index 9c4ce259bf8..b33eae85ba1 100644 --- a/tensorflow/core/common_runtime/partitioning_utils_test.cc +++ b/tensorflow/core/common_runtime/partitioning_utils_test.cc @@ -158,14 +158,23 @@ TEST_F(PartitioningUtilsTest, TwoDevices) { ASSERT_EQ(3, part2->num_op_nodes()); } -void CheckIndices(const std::vector& expected, - const std::vector& actual) { +void CheckRetIndices(const std::vector& expected, + const std::vector& actual) { ASSERT_EQ(expected.size(), actual.size()); for (int i = 0; i < expected.size(); ++i) { ASSERT_EQ(expected[i], actual[i]) << " at index " << i; } } +void CheckArgIndices(const std::vector& expected, + const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (int i = 0; i < expected.size(); ++i) { + ASSERT_EQ(expected[i].index, actual[i].index) << " at index " << i; + ASSERT_EQ(expected[i].sub_index, actual[i].sub_index) << " at index " << i; + } +} + void CheckAlloc(const std::vector& expected, const std::vector& actual) { ASSERT_EQ(expected.size(), actual.size()); @@ -185,7 +194,7 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) { auto graph = absl::make_unique(OpRegistry::Global()); SubGraph(graph.get(), DT_FLOAT, {3}, {5}); - std::vector arg_indices; + std::vector arg_indices; std::vector ret_indices; std::vector arg_alloc_attrs; std::vector ret_alloc_attrs; @@ -197,8 +206,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) { &ret_alloc_attrs); ASSERT_TRUE(status.ok()) << status.ToString(); - CheckIndices({3}, arg_indices); - CheckIndices({5}, ret_indices); + CheckArgIndices({{3, -1}}, arg_indices); + CheckRetIndices({5}, ret_indices); CheckAlloc({false}, arg_alloc_attrs); CheckAlloc({false}, ret_alloc_attrs); @@ -213,7 +222,18 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) { auto graph = absl::make_unique(OpRegistry::Global()); SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10}); - std::vector arg_indices; + const std::map sub_indices = { + {7, 2}, {3, 1}, {1, 0}, {5, 2}, {9, 0}}; + const AttrValue* attr_value; + for (Node* n : graph->op_nodes()) { + if (n->IsArg()) { + TF_ASSERT_OK(n->attrs().Find("index", &attr_value)); + n->AddAttr("sub_index", + sub_indices.at(static_cast(attr_value->i()))); + } + } + + std::vector arg_indices; std::vector ret_indices; std::vector arg_alloc_attrs; std::vector ret_alloc_attrs; @@ -225,8 +245,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) { &ret_alloc_attrs); ASSERT_TRUE(status.ok()) << status.ToString(); - CheckIndices({1, 3, 5, 7, 9}, arg_indices); - CheckIndices({2, 4, 6, 8, 10}, ret_indices); + CheckArgIndices({{1, 0}, {3, 1}, {5, 2}, {7, 2}, {9, 0}}, arg_indices); + CheckRetIndices({2, 4, 6, 8, 10}, ret_indices); CheckAlloc({false, false, false, false, false}, arg_alloc_attrs); CheckAlloc({false, false, false, false, false}, ret_alloc_attrs); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index c447832c91b..271169f2a5e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/rendezvous_util.h" +#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/op_kernel.h" @@ -301,7 +303,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( // Replace the given handle with the handle for the single component // function. - handle = component_data.handle_; + handle = component_data.handle; } auto iter = function_data_.find(handle); @@ -777,6 +779,14 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); + // Expand the nodes assigned to a CompositeDevice before graph partitioning to + // avoid generating a subgraph on a virtual device for execution. + // This transformation should happen as late as possible, so that as many + // graph optimization passes as possible (e.g. PRE_PLACEMENT, PLACER, + // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) run on the smaller, unexpanded graph. + TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph( + options.composite_devices, graph.get())); + if (options.graph_collector != nullptr) { GraphDef def; graph->ToGraphDef(&def); @@ -869,9 +879,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( Graph* subgraph = pair.second.get(); status->Update(UpdateArgAndRetvalMetadata( - subgraph, device_type, &comp_data->arg_indices_, - &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_, - &comp_data->ret_alloc_attrs_)); + subgraph, device_type, &comp_data->arg_indices, + &comp_data->ret_indices, &comp_data->arg_alloc_attrs, + &comp_data->ret_alloc_attrs)); if (!status->ok()) { counter.DecrementCount(); return; } @@ -913,7 +923,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( data->is_cross_process_ = true; } } - comp_data->handle_ = *component_handle; + comp_data->handle = *component_handle; } delete component_handle; counter.DecrementCount(); @@ -955,16 +965,16 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices( for (const auto& pair : data->glue_) { const ComponentFunctionData& comp_data = pair.second; - DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size()); + DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size()); const string& target = pair.first; FunctionLibraryRuntime* target_flr = GetFLR(target); if (target_flr == nullptr) { - if (!comp_data.ret_indices_.empty()) { + if (!comp_data.ret_indices.empty()) { return errors::Unimplemented( "Currently, outputting tensors on remote devices is not supported. 
" "The ", - comp_data.ret_indices_[0], + comp_data.ret_indices[0], "-th return value of the function outputs to target_device: ", target, " Please copy the tensor to local device explicitly using " @@ -973,17 +983,17 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices( continue; } Device* target_device = target_flr->device(); - const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_); + const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle); DCHECK(fbody != nullptr); output_devices->resize(data->num_outputs_); - for (int j = 0; j < comp_data.ret_indices_.size(); ++j) { - int ret_index = comp_data.ret_indices_[j]; + for (int j = 0; j < comp_data.ret_indices.size(); ++j) { + int ret_index = comp_data.ret_indices[j]; if (fbody->ret_types[j] == DT_RESOURCE) { (*output_devices)[ret_index] = target_device; } else { (*output_devices)[ret_index] = - comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device; + comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device; } } } @@ -1013,9 +1023,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( const MultiDeviceFunctionData* data = IsMultiDevice(handle); if (data == nullptr) { - done( - errors::InvalidArgument("Failed for find multi-device function handle ", - handle, ". Was the function instantiated?")); + done(errors::NotFound("Multi-device function handle ", handle, + " not found. Was the function instantiated?")); return; } @@ -1046,10 +1055,10 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( for (const auto& pair : data->glue_) { const string& target = pair.first; const ComponentFunctionData& comp_data = pair.second; - FunctionLibraryRuntime::Handle handle = pair.second.handle_; + FunctionLibraryRuntime::Handle handle = pair.second.handle; - opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_; - opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_; + opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs; + opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs; opts_copy.remote_execution = false; InternalArgs comp_args; @@ -1086,7 +1095,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( Status(status.code(), function_and_msg)); } else { for (int i = 0; i < comp_rets->size(); ++i) { - (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i]; + (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i]; } } delete comp_rets; @@ -1108,7 +1117,7 @@ refcounted_done->UpdateStatus(status); } else { for (int i = 0; i < comp_rets->size(); ++i) { - (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i]; + (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i]; } } delete comp_rets; @@ -1225,7 +1234,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( Status overall_status; for (const auto& it : mdata->glue_) { const string& device = it.first; - FunctionLibraryRuntime::Handle flr_handle = it.second.handle_; + FunctionLibraryRuntime::Handle flr_handle = it.second.handle; FunctionLibraryRuntime* flr = GetFLR(device); if (flr == nullptr) { // TODO(nareshmodi): Implement DeregisterGraph call to remote device if @@ -1297,6 +1306,19 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback( }; } +Status ProcessFunctionLibraryRuntime::CreateRendezvous( + const FunctionLibraryRuntime::Options& opts, + Rendezvous** created_rendezvous) const { + if (rendezvous_factory_) { + return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous); + } else { + return errors::FailedPrecondition( + "The caller does not provide a 
rendezvous and " + "ProcessFunctionLibraryRuntime was created without a rendezvous " + "factory."); + } +} + void ProcessFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, @@ -1305,21 +1327,12 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::Options new_opts = opts; Rendezvous* created_rendezvous = nullptr; if (!opts.rendezvous) { - if (rendezvous_factory_) { - Status s = - rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous); - if (!s.ok()) { - done(s); - return; - } - new_opts.rendezvous = created_rendezvous; - } else { - done( - errors::FailedPrecondition("The caller does not provide a rendezvous " - "and ProcessFunctionLibraryRuntime was " - "created without a rendezvous factory.")); + Status s = CreateRendezvous(opts, &created_rendezvous); + if (!s.ok()) { + done(s); return; } + new_opts.rendezvous = created_rendezvous; new_opts.create_rendezvous = false; } @@ -1334,9 +1347,14 @@ void ProcessFunctionLibraryRuntime::Run( if (multi_device) { auto get_component_args = [&args](const ComponentFunctionData& comp_data, InternalArgs* comp_args) -> Status { - for (const auto& tensor : - GetArgsForIndices(comp_data.arg_indices_, args)) { - comp_args->args.push_back(tensor); + // "Index"s of _Arg nodes are unique when all arguments are local Tensors. + for (const auto& it : comp_data.arg_indices) { + if (it.sub_index >= 0) { + return errors::InvalidArgument("Got unexpected sub_index ", + it.sub_index, " for argument ", + it.index); + } + comp_args->args.push_back(args[it.index]); } return Status::OK(); }; @@ -1520,11 +1538,23 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { - if (!args.HasRemoteInputs()) { + if (!args.HasRemoteOrPackedInputs()) { const std::vector local_inputs = args.GetLocalTensors(); return Run(opts, handle, local_inputs, rets, std::move(done)); } + FunctionLibraryRuntime::Options new_opts = opts; + Rendezvous* created_rendezvous = nullptr; + if (!opts.rendezvous) { + Status s = CreateRendezvous(opts, &created_rendezvous); + if (!s.ok()) { + done(s); + return; + } + new_opts.rendezvous = created_rendezvous; + new_opts.create_rendezvous = false; + } + #if defined(IS_MOBILE_PLATFORM) done(errors::Unimplemented( "Remote inputs are not available on mobile devices.")); @@ -1532,12 +1562,12 @@ void ProcessFunctionLibraryRuntime::Run( #else // !IS_MOBILE_PLATFORM auto* cleanup_items = new std::vector>; done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id, - /*rendezvous=*/nullptr); + created_rendezvous); auto get_component_args = [&args](const ComponentFunctionData& comp_data, InternalArgs* comp_args) -> Status { - for (int i = 0; i < comp_data.arg_indices_.size(); ++i) { - const int index = comp_data.arg_indices_.at(i); + for (int i = 0; i < comp_data.arg_indices.size(); ++i) { + const FunctionArgIndex index = comp_data.arg_indices.at(i); Tensor tensor; if (args.GetLocalArg(index, &tensor).ok()) { comp_args->args.push_back(std::move(tensor)); @@ -1552,7 +1582,7 @@ void ProcessFunctionLibraryRuntime::Run( } return Status::OK(); }; - return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done), + return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done), std::move(get_component_args)); #endif // !IS_MOBILE_PLATFORM } diff --git 
a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 104872e5a1c..bc68c9c2807 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/function.h" @@ -40,16 +41,15 @@ class FunctionArgsInterface { public: virtual ~FunctionArgsInterface() {} - virtual bool HasRemoteInputs() const = 0; + virtual bool HasRemoteOrPackedInputs() const = 0; - virtual Status GetLocalArg(const int index, Tensor* val) const = 0; + virtual Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const = 0; virtual std::vector GetLocalTensors() const = 0; - virtual const gtl::InlinedVector* GetTensorValues() const = 0; - #if !defined(IS_MOBILE_PLATFORM) - virtual Status GetRemoteArg(const int index, + virtual Status GetRemoteArg(const FunctionArgIndex& index, eager::RemoteTensorHandle* val) const { return errors::Unimplemented( "Serializing a remote argument is not implemented."); @@ -217,6 +217,12 @@ class ProcessFunctionLibraryRuntime { return lib_def_; } + // Add a CompositeDevice to `device_set_` + void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + device_set_->AddDevice(d); + } + protected: friend class FunctionLibraryRuntimeImpl; @@ -232,21 +238,21 @@ class ProcessFunctionLibraryRuntime { // piece of a multi-device function) fits into the multi-device function. struct ComponentFunctionData { // The handle for the instantiated component function. - FunctionLibraryRuntime::Handle handle_; - // arg_indices_.size() is the number of arguments to the component function. + FunctionLibraryRuntime::Handle handle; + // arg_indices.size() is the number of arguments to the component function. // The i-th argument of the component function comes from the - // `arg_indices_[i]`-th argument of the multi-device function. - std::vector arg_indices_; - // ret_indices_.size() is the number of return values of the component + // `arg_indices[i]`-th argument of the multi-device function. + std::vector arg_indices; + // ret_indices.size() is the number of return values of the component // function. The i-th return value of the component function goes to the - // `ret_indices_[i]`-th return value of the multi-device function. - std::vector ret_indices_; - // arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to + // `ret_indices[i]`-th return value of the multi-device function. + std::vector ret_indices; + // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to // the component function. - std::vector arg_alloc_attrs_; - // ret_alloc_attrs_[i] are the allocator attributes of the i-th return value + std::vector arg_alloc_attrs; + // ret_alloc_attrs[i] are the allocator attributes of the i-th return value // of the component function. 
- std::vector ret_alloc_attrs_; + std::vector ret_alloc_attrs; }; // Data structure holding information for a single instantiated multi-device @@ -304,6 +310,9 @@ class ProcessFunctionLibraryRuntime { InternalArgs* args)> get_component_args) const; + Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts, + Rendezvous** created_rendezvous) const; + FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback( std::vector>* items, FunctionLibraryRuntime::DoneCallback done, const int64 step_id, diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index aa31e2b11f2..247b94dc58c 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" @@ -143,6 +144,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }})); } + void AddCompositeDevice(CompositeDevice* d) { + proc_flr_->AddCompositeDevice(d); + } + Status Instantiate( const string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, @@ -187,11 +192,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } + template Status RunWithRuntime( const string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, - const std::vector& args, std::vector rets, + const T& args, std::vector rets, ProcessFunctionLibraryRuntime* pflr) { FunctionLibraryRuntime::Handle handle; Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle); @@ -248,9 +254,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { Status Run(const string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, - const std::vector& args, std::vector rets) { - return RunWithRuntime(name, opts, attrs, instantiate_opts, args, rets, - proc_flr_.get()); + const std::vector& args, std::vector rets, + ProcessFunctionLibraryRuntime* pflr = nullptr) { + return RunWithRuntime>( + name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get()); + } + + Status RunWithPackedArgs( + const string& name, FunctionLibraryRuntime::Options opts, + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, + const FunctionArgsInterface& args, std::vector rets, + ProcessFunctionLibraryRuntime* pflr = nullptr) { + return RunWithRuntime( + name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get()); } Status RunInstantiated(FunctionLibraryRuntime::Handle handle, @@ -719,6 +736,112 @@ Tensor GetResourceHandle(const string& var_name, const string& container, return tensor; } +// Returns a function which adds two variables on different devices. 
+FunctionDef AddVarAcrossDevices() { + return FunctionDefHelper::Create( + // Name + "AddVarAcrossDevices", + // Args + {"x: resource"}, + // Return values + {"y: float"}, + // Attr def + {}, + // Nodes + { + {{"read0"}, + "ReadVariableOp", + {"x"}, + {{"dtype", DT_FLOAT}}, + {}, + "/device:CPU:0"}, + {{"read1"}, + "ReadVariableOp", + {"x"}, + {{"dtype", DT_FLOAT}}, + {}, + "/device:CPU:1"}, + {{"add"}, + "Add", + {"read0:value:0", "read1:value:0"}, + {{"T", DT_FLOAT}}, + {}, + "/device:CPU:0"}, + }, + {{"y", "add:z:0"}}); +} + +// An implementation of FunctionArgsInterface for packed inputs. +class TestFunctionPackedArgs : public FunctionArgsInterface { + public: + TestFunctionPackedArgs(const int index, + gtl::InlinedVector&& tensor_args) { + packed_args_.emplace(index, std::move(tensor_args)); + } + + ~TestFunctionPackedArgs() override{}; + + bool HasRemoteOrPackedInputs() const override { return true; }; + + Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override { + *val = *packed_args_.at(index.index).at(index.sub_index).tensor; + return Status::OK(); + }; + + std::vector GetLocalTensors() const override { return {}; } + + private: + absl::flat_hash_map> packed_args_; +}; + +TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { + Init({AddVarAcrossDevices()}); + // Create two variables on two devices. + const Tensor initial_resource_value0 = test::AsTensor({10, 20}); + Var* resource0 = new Var(DT_FLOAT); + *resource0->tensor() = initial_resource_value0; + resource0->is_initialized = true; + const Tensor initial_resource_value1 = test::AsTensor({30, 40}); + Var* resource1 = new Var(DT_FLOAT); + *resource1->tensor() = initial_resource_value1; + resource1->is_initialized = true; + ResourceMgr* mgr0 = device0_->resource_manager(); + ResourceMgr* mgr1 = device1_->resource_manager(); + TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0)); + TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1)); + + Tensor resource_handle0 = + GetResourceHandle("var", mgr0->default_container(), device0_->name()); + Tensor resource_handle1 = + GetResourceHandle("var", mgr1->default_container(), device1_->name()); + + // Create a CompositeDevice + Status s; + std::unique_ptr composite_device = + CompositeDevice::MakeDevice({device0_->name(), device1_->name()}, + /*unique_device_id=*/0, &s); + TF_ASSERT_OK(s); + AddCompositeDevice(composite_device.get()); + + FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::InstantiateOptions inst_opts = + MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"}); + inst_opts.composite_devices[composite_device->name()] = + composite_device->underlying_devices(); + inst_opts.input_resource_dtypes_and_shapes[0] = { + initial_resource_value0.dtype(), initial_resource_value0.shape()}; + + gtl::InlinedVector handles; + handles.push_back(TensorValue(&resource_handle0)); + handles.push_back(TensorValue(&resource_handle1)); + TestFunctionPackedArgs args(0, std::move(handles)); + Tensor ret; + TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, + inst_opts, args, {&ret})); + test::ExpectTensorEqual(ret, test::AsTensor({40, 60})); +} + TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) { if (gpu_device_ == nullptr) { GTEST_SKIP() << "No GPUs available"; @@ -1025,9 +1148,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) { instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; const auto x = test::AsTensor({17}); Tensor 
y; - TF_CHECK_OK(RunWithRuntime("SessionMetadataReaderFn", opts, {}, - instantiate_opts, {x}, {&y}, - cloned_proc_flr.get())); + TF_CHECK_OK(RunWithRuntime>( + "SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y}, + cloned_proc_flr.get())); SessionMetadata read_metadata; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar()(), &read_metadata)); diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc index 4c32a54aee4..3609a5e7e1f 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc @@ -42,6 +42,7 @@ class ReplicateHelper { Node* replicated_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); replicated_node->set_assigned_device_name(device); + replicated_node->AddAttr("sub_index", i); replicated_nodes[i] = replicated_node; } replicated_nodes_map_.emplace(node, std::move(replicated_nodes)); @@ -180,7 +181,8 @@ Status ReplicateEdges(const ReplicateHelper& helper, } // namespace Status ReplicatePerReplicaNodesInFunctionGraph( - const absl::flat_hash_map>& composite_devices, + const absl::flat_hash_map*>& + composite_devices, Graph* graph) { std::set composite_device_names; for (const auto& it : composite_devices) { @@ -198,7 +200,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph( } for (const auto& it : composite_device_to_cluster_nodes) { - const std::vector& allowed_devices = composite_devices.at(it.first); + const std::vector& allowed_devices = + *composite_devices.at(it.first); if (allowed_devices.empty()) { return errors::InvalidArgument("No allowed device of composite device: ", it.first); @@ -208,6 +211,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph( // Reuse the original nodes if there is only one allowed device. for (Node* n : cluster_nodes) { n->set_assigned_device_name(allowed_devices.at(0)); + n->AddAttr("sub_index", 0); } continue; } diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h index 872e77c8671..fd696db4905 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h @@ -35,7 +35,8 @@ namespace tensorflow { // dependency. // TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass. 
Status ReplicatePerReplicaNodesInFunctionGraph( - const absl::flat_hash_map>& composite_devices, + const absl::flat_hash_map*>& + composite_devices, Graph* graph); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc index 094d86944ee..db05907710c 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc @@ -75,8 +75,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) { auto ret = ops::_Retval( scope.WithOpName("ret").WithControlDependencies({write}), read, 0); - const absl::flat_hash_map> composite_devices = { - {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}}; + const std::vector underlying_devices = {"TPU:0", "TPU:1"}; + const absl::flat_hash_map*> + composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}}; Graph graph(OpRegistry::Global()); TF_ASSERT_OK(scope.ToGraph(&graph)); @@ -118,8 +119,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) { auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32); auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0); - const absl::flat_hash_map> composite_devices = { - {"TPU_COMPOSITE:0", {"TPU:0"}}}; + const std::vector underlying_devices = {"TPU:0"}; + const absl::flat_hash_map*> + composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}}; Graph graph(OpRegistry::Global()); TF_ASSERT_OK(scope.ToGraph(&graph)); @@ -156,9 +158,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) { auto add = ops::Add(scope.WithOpName("add"), identity0, identity1); auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0); - const absl::flat_hash_map> composite_devices = { - {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}, - {"TPU_COMPOSITE:1", {"TPU:2", "TPU:3"}}}; + const std::vector underlying_devices_0 = {"TPU:0", "TPU:1"}; + const std::vector underlying_devices_1 = {"TPU:2", "TPU:3"}; + const absl::flat_hash_map*> + composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0}, + {"TPU_COMPOSITE:1", &underlying_devices_1}}; Graph graph(OpRegistry::Global()); TF_ASSERT_OK(scope.ToGraph(&graph)); @@ -204,8 +208,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) { } TEST(ReplicatePerReplicaNodesTest, NestedFunctions) { - const absl::flat_hash_map> composite_devices = { - {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}}; + const std::vector underlying_devices = {"TPU:0", "TPU:1"}; + const absl::flat_hash_map*> + composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}}; FunctionDefLibrary fdef_lib; FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 76ca5c318fb..92e92f47356 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -500,11 +500,11 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { : EagerKernelArgs(std::move(tensor_args)), serialize_remote_handle_(std::move(serialize_remote_handle)) {} - bool HasRemoteInputs() const override { return true; } + bool HasRemoteOrPackedInputs() const override { return true; } - Status GetRemoteArg(const int index, + Status GetRemoteArg(const FunctionArgIndex& index, eager::RemoteTensorHandle* val) const override { - return 
serialize_remote_handle_(index, val); + return serialize_remote_handle_(index.index, val); } private: @@ -562,7 +562,14 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { eager_pflr_ = absl::make_unique( remote_device_mgr_.get(), Env::Default(), /*config=*/ nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(), - /*thread_pool=*/nullptr, eager_cluster_flr_.get()); + /*thread_pool=*/nullptr, eager_cluster_flr_.get(), + /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, + Rendezvous::Factory{[this](const int64 step_id, + const DeviceMgr* device_mgr, + Rendezvous** r) { + *r = worker_env_.rendezvous_mgr->Find(step_id); + return Status::OK(); + }}); } void CheckOutputTensorAndClose(const Tensor& tensor) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index b64047e999f..788c49675e5 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" // clang-format on +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -525,6 +526,20 @@ class Device; // Forward declare. Defined in common_runtime/device_mgr.h class DeviceMgr; +// Index of an _Arg node. +struct FunctionArgIndex { + explicit FunctionArgIndex(const int index) : index(index) {} + FunctionArgIndex(const int index, const int sub_index) + : index(index), sub_index(sub_index) {} + + // The value of the attribute "Index" of the _Arg node. + int index; + // Set only when the _Arg node represents multiple arguments (e.g. an _Arg + // node is replicated to multiple devices/subgraphs). Use sub-index to + // distinguish arguments with the same index. + int sub_index = -1; +}; + class FunctionLibraryRuntime { public: virtual ~FunctionLibraryRuntime() {} @@ -576,6 +591,10 @@ class FunctionLibraryRuntime { // infer correct device. std::vector output_devices; + // Maps from a CompositeDevice name to a list of underlying physical + // devices. + absl::flat_hash_map*> composite_devices; + // This interface is EXPERIMENTAL and subject to change. // // For multi-device functions, a mapping from _Arg node index to type and