From bae9776f5793d55b7710c8585c3d88e6c8ec2a96 Mon Sep 17 00:00:00 2001
From: Yujing Zhang
Date: Mon, 9 Mar 2020 16:37:17 -0700
Subject: [PATCH] Roll forward https://github.com/tensorflow/tensorflow/commit/ed0e46c219efa62559901...

PiperOrigin-RevId: 299965831
Change-Id: I22b867dd0c7c95552d52b49595f1f5b5279639ec
---
 tensorflow/compiler/tf2xla/tf2xla.cc          | 26 +++++++---
 .../base_api/api_def_VarHandleOp.pbtxt        |  7 +++
 .../core/common_runtime/eager/execute.cc      | 33 +++++++++---
 .../common_runtime/eager/tensor_handle.cc     | 52 +++++++++++++------
 .../core/common_runtime/eager/tensor_handle.h | 20 ++++---
 .../distributed_runtime/eager/remote_mgr.cc   | 12 +++--
 tensorflow/core/framework/resource_mgr.h      | 12 +++--
 .../core/kernels/resource_variable_ops.cc     |  8 ++-
 .../core/kernels/resource_variable_ops.h      |  4 ++
 tensorflow/core/ops/resource_variable_ops.cc  |  1 +
 .../resource_variable_ops_test.py             | 36 +++++++++++++
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |  2 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |  2 +-
 13 files changed, 167 insertions(+), 48 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 9ced6e682fc..bcdfd1c6a8e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/versions.pb.h"
@@ -42,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/dump_graph.h"
@@ -128,19 +130,31 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
-void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
-  for (auto& node : *graph_def->mutable_node()) {
+Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
+  auto update_var_handle_op_node = [](NodeDef& node) -> Status {
     if (node.op() == "VarHandleOp") {
       node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+      const auto& it = node.attr().find("allowed_devices");
+      if (it != node.attr().end()) {
+        if (!it->second.list().s().empty()) {
+          // TODO(b/149512838): Support non-empty allowed devices.
+          return errors::InvalidArgument(
+              "VarHandleOp with non-empty allowed devices is not supported.");
+        }
+        node.mutable_attr()->erase("allowed_devices");
+      }
     }
+    return Status::OK();
+  };
+  for (auto& node : *graph_def->mutable_node()) {
+    TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
   }
   for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
     for (auto& node : *fn.mutable_node_def()) {
-      if (node.op() == "VarHandleOp") {
-        node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
-      }
+      TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
     }
   }
+  return Status::OK();
 }
 
 }  // namespace
@@ -149,7 +163,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
                             xla::Client* client,
                             xla::XlaComputation* computation) {
   std::unique_ptr<Graph> graph;
-  ConvertVarHandlesToAotVarHandles(&graph_def);
+  TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def));
   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
   TF_RETURN_IF_ERROR(
       ConvertGraphToXla(std::move(graph), config, client, computation));
diff --git a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
index 0a4caa06bdb..39606a07184 100644
--- a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
@@ -23,6 +23,13 @@ END
     name: "shape"
     description: <<END
 The (possibly partially specified) shape of this variable.
 END
   }
+  attr {
+    name: "allowed_devices"
+    description: <<END
+The allowed devices containing the resource variable. Set when the output
+ResourceHandle represents a per-replica/partitioned resource variable.
+END
+  }
 }
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ ... @@
-      std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
-      TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
-          &resource_dtypes_and_shapes));
-      if (!resource_dtypes_and_shapes.empty()) {
+      TensorHandle::ResourceHandleInfo resource_handle_info;
+      TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
+      std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
+          &resource_handle_info.dtypes_and_shapes;
+      if (!resource_dtypes_and_shapes->empty()) {
         const DtypeAndPartialTensorShape& dtype_and_shape =
-            resource_dtypes_and_shapes.at(0);
+            resource_dtypes_and_shapes->at(0);
         input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
 
         // Add _Arg index, dtype and shape to "cache_key".
@@ -647,8 +648,13 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
     const AttrValue* shape;
     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
-    retvals[0]->SetResourceHandleDtypeAndShape(
-        {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
+    TensorHandle::ResourceHandleInfo resource_handle_info = {
+        {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
+    // "allowed_devices" is set only when the output represents a
+    // per-replica/partitioned resource variable.
+    TryGetNodeAttr(attr_slice, "allowed_devices",
+                   &resource_handle_info.allowed_devices);
+    retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
   }
   return Status::OK();
 }
@@ -869,6 +875,19 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
       // is a resource we must pin it to prevent different device selection.
       // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
       if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
+        std::vector<string> allowed_devices;
+        TF_RETURN_IF_ERROR(
+            tensor_handle->GetResourceAllowedDevices(&allowed_devices));
+        if (!allowed_devices.empty()) {
+          // TODO(b/145922293): Support allowed_devices specified in wildcard
+          // patterns.
+          std::vector<string> device_names;
+          if (std::find(allowed_devices.begin(), allowed_devices.end(),
+                        op->GetDeviceName()) != allowed_devices.end()) {
+            TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
+                op->GetDeviceName().c_str(), &resource_device));
+          }
+        }
         DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                  << "device of operation " << op->Name() << " to "
                  << resource_device->name() << " because input #" << i
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index e7e2fb7b197..dc805d091bf 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -55,13 +55,13 @@ const int32 kInvalidOutputNum = -1;
 #endif
 }  // namespace
 
-void TensorHandle::SetResourceHandleDtypeAndShape(
-    std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
-  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
+void TensorHandle::SetResourceHandleInfo(
+    ResourceHandleInfo&& resource_handle_info) {
+  resource_handle_info_ = std::move(resource_handle_info);
 }
 
-Status TensorHandle::GetResourceHandleDtypesAndShapes(
-    std::vector<DtypeAndPartialTensorShape>* result) {
+Status TensorHandle::GetResourceHandleInfoImpl(
+    std::function<void()> set_resource_info) {
   if (dtype != DT_RESOURCE) {
     return errors::InvalidArgument(
         "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
@@ -70,22 +70,42 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
   }
 
   if (IsRemote()) {
-    *result = handle_dtypes_and_shapes_;
+    set_resource_info();
     return Status::OK();
   }
 
   // Wait for this TensorHandle to be ready.
-  profiler::TraceMe activity(
-      "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
-      profiler::TraceMeLevel::kInfo);
+  profiler::TraceMe activity("TensorHandle::GetResourceHandleInfo WaitReady",
+                             profiler::TraceMeLevel::kInfo);
 
   auto& data = absl::get<LocalTensorHandleData>(data_);
-  TF_RETURN_IF_ERROR(
-      data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
+  TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));
 
-  *result = handle_dtypes_and_shapes_;
+  set_resource_info();
   return Status::OK();
 }
 
+Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
+Status TensorHandle::GetResourceHandleDtypesAndShapes(
+    std::vector<DtypeAndPartialTensorShape>* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_.dtypes_and_shapes;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
+Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_.allowed_devices;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
 Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
@@ -145,8 +165,9 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                            GetResourceDevice(t.flat<ResourceHandle>()(0), ctx)),
       ctx_(ctx),
       implicit_mirroring_(true),
-      handle_dtypes_and_shapes_(
-          t.flat<ResourceHandle>()(0).dtypes_and_shapes()),
+      resource_handle_info_(
+          {t.flat<ResourceHandle>()(0).dtypes_and_shapes(),
+           t.flat<ResourceHandle>()(0).allowed_devices()}),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_)
@@ -669,7 +690,8 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
 
     if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
       auto& resource_handle = t.flat<ResourceHandle>()(0);
-      handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
+      resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
+                               resource_handle.allowed_devices()};
     }
     auto& data = absl::get<LocalTensorHandleData>(data_);
     return data.SetTensor(std::move(t));
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 783a8907f05..030976f32b8 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -210,13 +210,19 @@ class TensorHandle : public core::RefCounted {
 
   string DebugString() const;
 
-  void SetResourceHandleDtypeAndShape(
-      std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);
+  struct ResourceHandleInfo {
+    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
+    std::vector<string> allowed_devices;
+  };
+
+  void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info);
 
   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
-  // return data types and shapes of the underlying resource.
+  // return data types, shapes and allowed devices of the underlying resource.
+  Status GetResourceHandleInfo(ResourceHandleInfo* result);
   Status GetResourceHandleDtypesAndShapes(
       std::vector<DtypeAndPartialTensorShape>* result);
+  Status GetResourceAllowedDevices(std::vector<string>* result);
 
  private:
   // The TensorHandleData can either represent a local or remote tensor handle.
@@ -225,6 +231,8 @@
   // with a ready version of the tensor handle data.
   bool IsReady() const;
 
+  Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
+
   VariantDevice const device_;
 
   // Device in which the op producing this tensor was executed. Equals to
@@ -268,9 +276,9 @@
   bool implicit_mirroring_;
 
   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
-  // refers to a remote resource handle, we store data types and shapes for
-  // the underlying resource.
-  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
+  // refers to a remote resource handle, we store data types, shapes and allowed
+  // devices for the underlying resource.
+  ResourceHandleInfo resource_handle_info_;
 
   // Does not need synchronization because it can be accessed only after
   // WaitReady() has returned. At that point, data_ is immutable.
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index 51c1e763021..54fb10e721d 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -164,22 +164,24 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
         parent_->FindDeviceFromName(device_name.c_str(), &device));
     TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
         in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
+    TensorHandle::ResourceHandleInfo resource_handle_info;
+    std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
+        &resource_handle_info.dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
-                                  &dtypes_and_shapes)
+                                  dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
-        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
+        dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
-         dtypes_and_shapes);
+         *dtypes_and_shapes);
    }
-    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
+    (*out)->SetResourceHandleInfo(std::move(resource_handle_info));
   }
   return Status::OK();
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index b4c4906ef27..2cf59d2845f 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -314,11 +314,13 @@ ResourceHandle MakeResourceHandle(
 template <typename T>
 ResourceHandle MakeResourceHandle(
     OpKernelConstruction* ctx, const string& container, const string& name,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
-  return MakeResourceHandle(
-      container.empty() ? ctx->resource_manager()->default_container()
-                        : container,
-      name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
+    const std::vector<string>& allowed_devices = {}) {
+  return MakeResourceHandle(container.empty()
+                                ? ctx->resource_manager()->default_container()
+                                : container,
+                            name, *ctx->device(), MakeTypeIndex<T>(),
+                            dtypes_and_shapes, allowed_devices);
 }
 
 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 6a1cf9e570c..c8a08c65ab9 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -229,6 +229,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
   OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
   PartialTensorShape shape;
   OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
+  OP_REQUIRES_OK(context,
+                 context->GetAttr("allowed_devices", &allowed_devices_));
 
   is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
 
@@ -239,7 +241,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
                                             &resource_, attr));
     resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         context, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
+        allowed_devices_);
   }
 }
 
@@ -252,7 +255,8 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
     handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         ctx, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
+        allowed_devices_);
     ctx->set_output(0, handle);
   } else {
     ctx->set_output(0, resource_);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 1bb70b537c1..5935fa91d21 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -36,6 +36,10 @@ class VarHandleOp : public OpKernel {
 
   Tensor resource_;
   DtypeAndPartialTensorShape dtype_and_shape_;
+
+  // A set of devices containing the resource variable. Set when the output
+  // ResourceHandle represents a per-replica/partitioned resource variable.
+  std::vector<string> allowed_devices_;
 };
 
 class ReadVariableOp : public OpKernel {
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 696a69eff80..77ab5f604c8 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -80,6 +80,7 @@ REGISTER_OP("VarHandleOp")
     .Attr("shared_name: string = ''")
     .Attr("dtype: type")
     .Attr("shape: shape")
+    .Attr("allowed_devices: list(string) = []")
     .Output("resource: resource")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f20e54d18a5..cbd8f6a2ebe 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -30,6 +30,7 @@ from tensorflow.core.framework import tensor_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
@@ -1488,5 +1489,40 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     self.assertAllEqual(expected, result)
 
 
+class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    super(PerReplicaResourceHandleTest, self).setUp()
+    cpus = config.list_physical_devices("CPU")
+    # Set 2 virtual CPUs
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+    ])
+
+  def testAllowedDevices(self):
+    device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
+    device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
+    value0 = 1
+    value1 = 2
+    with context.eager_mode():
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
+      with ops.device(device0):
+        assign0 = resource_variable_ops.assign_variable_op(handle, value0)
+      with ops.device(device1):
+        assign1 = resource_variable_ops.assign_variable_op(handle, value1)
+      with ops.control_dependencies([assign0, assign1]):
+        with ops.device(device0):
+          read0 = resource_variable_ops.read_variable_op(
+              handle, dtype=dtypes.int32)
+        with ops.device(device1):
+          read1 = resource_variable_ops.read_variable_op(
+              handle, dtype=dtypes.int32)
+
+      self.assertAllEqual(value0, read0)
+      self.assertAllEqual(value1, read1)
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 686fd8653c7..d62a863d710 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -4962,7 +4962,7 @@ tf_module {
   }
   member_method {
     name: "VarHandleOp"
-    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
+    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
   }
   member_method {
     name: "VarIsInitializedOp"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 686fd8653c7..d62a863d710 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -4962,7 +4962,7 @@ tf_module {
   }
   member_method {
     name: "VarHandleOp"
-    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
+    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
  }
  member_method {
    name: "VarIsInitializedOp"
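
Reviewer note, not part of the patch: the behavioral core of this change is the new branch in MaybeUpdateOpDevice. Before the change, an op touching a resource input was always pinned to the device that created the handle. With a non-empty "allowed_devices" list, an op that explicitly requests one of the allowed devices is pinned to that device instead, so each replica can address its own copy of the variable. A minimal Python model of that decision follows; pick_resource_device is an illustrative name, not TensorFlow API, and matching is exact string comparison (wildcards remain TODO, b/145922293):

    def pick_resource_device(op_device, resource_device, allowed_devices):
      # Mirrors the new branch in MaybeUpdateOpDevice: pin to the op's
      # requested device only if the handle explicitly allows it; otherwise
      # fall back to the handle's home device, as before this patch.
      if allowed_devices and op_device in allowed_devices:
        return op_device
      return resource_device

    allowed = ["/device:CPU:0", "/device:CPU:1"]
    assert pick_resource_device("/device:CPU:1", "/device:CPU:0", allowed) == "/device:CPU:1"
    assert pick_resource_device("/device:GPU:0", "/device:CPU:0", allowed) == "/device:CPU:0"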
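
Usage sketch, not part of the patch: the snippet below restates the flow of the new PerReplicaResourceHandleTest outside the test harness. Device strings and API calls are taken from that test; it assumes it runs at process start, before TensorFlow initializes its devices (set_logical_device_configuration fails otherwise).

    from tensorflow.python.eager import context
    from tensorflow.python.framework import config, dtypes, ops
    from tensorflow.python.ops import resource_variable_ops

    cpus = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

    device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
    device1 = "/job:localhost/replica:0/task:0/device:CPU:1"

    with context.eager_mode():
      # One handle, two allowed devices: each device holds its own copy of
      # the variable, and assign/read ops placed on an allowed device act on
      # that device's copy instead of being repinned to the handle's home.
      handle = resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
      with ops.device(device0):
        resource_variable_ops.assign_variable_op(handle, 1)
      with ops.device(device1):
        resource_variable_ops.assign_variable_op(handle, 2)
      with ops.device(device0):
        print(resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))  # 1
      with ops.device(device1):
        print(resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))  # 2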