diff --git a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt index 39606a07184..29ffcdaad6b 100644 --- a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt @@ -28,8 +28,8 @@ END attr { name: "allowed_devices" description: <GetResourceHandleInfo(&resource_handle_info)); - std::vector* resource_dtypes_and_shapes = - &resource_handle_info.dtypes_and_shapes; - if (!resource_dtypes_and_shapes->empty()) { + std::vector resource_dtypes_and_shapes; + TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes( + &resource_dtypes_and_shapes)); + if (!resource_dtypes_and_shapes.empty()) { const DtypeAndPartialTensorShape& dtype_and_shape = - resource_dtypes_and_shapes->at(0); + resource_dtypes_and_shapes.at(0); input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape; // Add _Arg index, dtype and shape to "cache_key". @@ -695,13 +694,8 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op, TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype)); const AttrValue* shape; TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape)); - TensorHandle::ResourceHandleInfo resource_handle_info = { - {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}}; - // "allowed_devices" is set only when the output represents a - // per-replica/partitioned resource variable. - TryGetNodeAttr(attr_slice, "allowed_devices", - &resource_handle_info.allowed_devices); - retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info)); + retvals[0]->SetResourceHandleDtypeAndShape( + {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}); } return Status::OK(); } @@ -985,18 +979,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) { // is a resource we must pin it to prevent different device selection. // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up. if (resource_device != op_device || op->Device() == kVariantDeviceNull) { - std::vector allowed_devices; - TF_RETURN_IF_ERROR( - tensor_handle->GetResourceAllowedDevices(&allowed_devices)); - if (!allowed_devices.empty()) { - // TODO(b/145922293): Support allowed_devices specified in wildcard - // patterns. - if (std::find(allowed_devices.begin(), allowed_devices.end(), - op->DeviceName()) != allowed_devices.end()) { - TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(), - &resource_device)); - } - } DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ") << "device of operation " << op->Name() << " to " << resource_device->name() << " because input #" << i diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index ef3e7a3cd28..0cd55959924 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -145,13 +145,13 @@ Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle( return Status::OK(); } -void TensorHandle::SetResourceHandleInfo( - ResourceHandleInfo&& resource_handle_info) { - resource_handle_info_ = std::move(resource_handle_info); +void TensorHandle::SetResourceHandleDtypeAndShape( + std::vector dtypes_and_shapes) { + handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes); } -Status TensorHandle::GetResourceHandleInfoImpl( - std::function set_resource_info) { +Status TensorHandle::GetResourceHandleDtypesAndShapes( + std::vector* result) { if (dtype != DT_RESOURCE) { return errors::InvalidArgument( "TensorHandle::GetResourceDtypeAndShape should be called on tensor " @@ -160,7 +160,7 @@ Status TensorHandle::GetResourceHandleInfoImpl( } if (Type() != LOCAL) { - set_resource_info(); + *result = handle_dtypes_and_shapes_; return Status::OK(); } @@ -170,32 +170,10 @@ Status TensorHandle::GetResourceHandleInfoImpl( auto& data = absl::get(data_); TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo")); - set_resource_info(); + *result = handle_dtypes_and_shapes_; return Status::OK(); } -Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) { - auto get_resource_info = [result, this]() { - *result = resource_handle_info_; - }; - return GetResourceHandleInfoImpl(get_resource_info); -} - -Status TensorHandle::GetResourceHandleDtypesAndShapes( - std::vector* result) { - auto get_resource_info = [result, this]() { - *result = resource_handle_info_.dtypes_and_shapes; - }; - return GetResourceHandleInfoImpl(get_resource_info); -} - -Status TensorHandle::GetResourceAllowedDevices(std::vector* result) { - auto get_resource_info = [result, this]() { - *result = resource_handle_info_.allowed_devices; - }; - return GetResourceHandleInfoImpl(get_resource_info); -} - int TensorHandle::NumPackedHandles() const { if (Type() != PACKED) { return 0; @@ -270,9 +248,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, resource_remote_device_incarnation_( GetRemoteDeviceIncarnation(resource_device_)), ctx_(ctx), - resource_handle_info_( - {t.flat()(0).dtypes_and_shapes(), - t.flat()(0).allowed_devices()}), + handle_dtypes_and_shapes_( + t.flat()(0).dtypes_and_shapes()), data_(absl::in_place_type, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_) @@ -327,10 +304,10 @@ Status TensorHandle::CreatePackedHandle(std::vector&& handles, return errors::InvalidArgument("Handles should not be empty."); } - ResourceHandleInfo resource_handle_info; + std::vector dtypes_and_shapes; if (dtype == DT_RESOURCE) { TF_RETURN_IF_ERROR( - handles.at(0)->GetResourceHandleInfo(&resource_handle_info)); + handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes)); } std::vector devices; for (auto* handle : handles) { @@ -348,7 +325,8 @@ Status TensorHandle::CreatePackedHandle(std::vector&& handles, &composite_device)); *packed_handle = new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx); - (*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info)); + (*packed_handle) + ->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes)); return Status::OK(); } @@ -898,8 +876,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { auto& resource_handle = t.flat()(0); - resource_handle_info_ = {resource_handle.dtypes_and_shapes(), - resource_handle.allowed_devices()}; + handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes(); } auto& data = absl::get(data_); return data.SetTensor(std::move(t)); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 6ac48bdac26..8ef482cd82c 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -226,19 +226,13 @@ class TensorHandle : public ImmediateExecutionTensorHandle, string DebugString() const; - struct ResourceHandleInfo { - std::vector dtypes_and_shapes; - std::vector allowed_devices; - }; - - void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info); + void SetResourceHandleDtypeAndShape( + std::vector dtypes_and_shapes); // If this TensorHandle is 1) a local tensor, and 2) a resource handle, - // return data types, shapes and allowed devices of the underlying resource. - Status GetResourceHandleInfo(ResourceHandleInfo* result); + // return data types and shapes of the underlying resource. Status GetResourceHandleDtypesAndShapes( std::vector* result); - Status GetResourceAllowedDevices(std::vector* result); // Returns the number of packed handles. 0 if the handle type is not PACKED. int NumPackedHandles() const; @@ -261,8 +255,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle, // with a ready version of the tensor handle data. bool IsReady() const; - Status GetResourceHandleInfoImpl(std::function set_resource_info); - VariantDevice const device_; // Device in which the op producing this tensor was executed. Equals to @@ -308,9 +300,9 @@ class TensorHandle : public ImmediateExecutionTensorHandle, Status is_poisoned_; // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or - // refers to a remote resource handle, we store data types, shapes and allowed - // devices for the underlying resource. - ResourceHandleInfo resource_handle_info_; + // refers to a remote resource handle, we store data types and shapes for + // the underlying resource. + std::vector handle_dtypes_and_shapes_; // A handle data which refers to multiple TensorHandles of the same dtype and // shape. diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 28092c0a604..40cec3fcc49 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -150,13 +150,13 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { Device* d0 = ListDevices().at(0); TensorHandle* h0 = TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context()); - h0->SetResourceHandleInfo({{dtype_and_shape}, {}}); + h0->SetResourceHandleDtypeAndShape({dtype_and_shape}); handles.push_back(h0); Tensor t1(dtype, shape); Device* d1 = ListDevices().at(1); TensorHandle* h1 = TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context()); - h1->SetResourceHandleInfo({{dtype_and_shape}, {}}); + h1->SetResourceHandleDtypeAndShape({dtype_and_shape}); handles.push_back(h1); // Create 2 remote TensorHandles (not ready). @@ -185,13 +185,12 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { TensorShape packed_shape; TF_ASSERT_OK(packed_handle->Shape(&packed_shape)); EXPECT_EQ(packed_shape, shape); - TensorHandle::ResourceHandleInfo resource_handle_info; - TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info)); - EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1); - EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT); - EXPECT_EQ( - resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), - true); + std::vector dtypes_and_shapes; + TF_ASSERT_OK( + packed_handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes)); + EXPECT_EQ(dtypes_and_shapes.size(), 1); + EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT); + EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true); CompositeDevice* device = reinterpret_cast( absl::get(packed_handle->device())); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index 94a4f199337..9003f2b3f17 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -167,24 +167,22 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, parent_->FindDeviceFromName(device_name.c_str(), &device)); *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(), in.dtype(), device, parent_); - TensorHandle::ResourceHandleInfo resource_handle_info; - std::vector* dtypes_and_shapes = - &resource_handle_info.dtypes_and_shapes; + std::vector dtypes_and_shapes; if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in), - dtypes_and_shapes) + &dtypes_and_shapes) .ok()) { for (const auto& dtype_and_shape_proto : in.resource_dtypes_and_shapes()) { - dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{ + dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{ dtype_and_shape_proto.dtype(), TensorShape(dtype_and_shape_proto.shape())}); } mutex_lock l(mirrored_resource_shape_mu_); mirrored_resource_shape_map_.emplace( RemoteTensorHandleInternal(in.op_id(), in.output_num()), - *dtypes_and_shapes); + dtypes_and_shapes); } - (*out)->SetResourceHandleInfo(std::move(resource_handle_info)); + (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes)); } return Status::OK(); diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc index 2db5cfa301c..e7f4c2afc90 100644 --- a/tensorflow/core/framework/resource_handle.cc +++ b/tensorflow/core/framework/resource_handle.cc @@ -38,9 +38,6 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const { dtype_and_shape->set_dtype(dtype_and_shape_pair.dtype); dtype_and_shape_pair.shape.AsProto(dtype_and_shape->mutable_shape()); } - for (const string& device : allowed_devices_) { - *proto->add_allowed_devices() = device; - } } void ResourceHandle::FromProto(const ResourceHandleProto& proto) { @@ -56,9 +53,6 @@ void ResourceHandle::FromProto(const ResourceHandleProto& proto) { dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape}); } dtypes_and_shapes_ = std::move(dtypes_and_shapes); - for (const string& device : proto.allowed_devices()) { - allowed_devices_.push_back(device); - } } string ResourceHandle::SerializeAsString() const { diff --git a/tensorflow/core/framework/resource_handle.h b/tensorflow/core/framework/resource_handle.h index 88c9f9da190..9acb94b6e79 100644 --- a/tensorflow/core/framework/resource_handle.h +++ b/tensorflow/core/framework/resource_handle.h @@ -39,14 +39,8 @@ class ResourceHandle { // Unique name for the device containing the resource. const std::string& device() const { return device_; } - // Names of the devices containing the resource. - const std::vector& allowed_devices() const { - return allowed_devices_; - } + void set_device(const std::string& device) { device_ = device; } - void set_allowed_devices(const std::vector& devices) { - allowed_devices_ = devices; - } // Container in which this resource is placed. const std::string& container() const { return container_; } @@ -93,12 +87,7 @@ class ResourceHandle { "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; public: - // The default device containing the resource, where the ResourceHandle is - // initially created. std::string device_; - // A set of devices containing the resource. If empty, the resource only - // exists on device_. Can be represented in wildcard patterns. - std::vector allowed_devices_; std::string container_; std::string name_; uint64 hash_code_ = 0; diff --git a/tensorflow/core/framework/resource_handle.proto b/tensorflow/core/framework/resource_handle.proto index eb0d1631c2f..5a41750475d 100644 --- a/tensorflow/core/framework/resource_handle.proto +++ b/tensorflow/core/framework/resource_handle.proto @@ -41,7 +41,5 @@ message ResourceHandleProto { // Data types and shapes for the underlying resource. repeated DtypeAndShape dtypes_and_shapes = 6; - // A set of devices containing the resource. If empty, the resource only - // exists on `device`. - repeated string allowed_devices = 7; + reserved 7; } diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index fd524b05bb9..e6ecfbb9190 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -36,8 +36,7 @@ static std::atomic current_id_; ResourceHandle MakeResourceHandle( const string& container, const string& name, const DeviceBase& device, const TypeIndex& type_index, - const std::vector& dtypes_and_shapes, - const std::vector& allowed_devices) { + const std::vector& dtypes_and_shapes) { ResourceHandle result; result.set_device(device.name()); result.set_container(container); @@ -49,7 +48,6 @@ ResourceHandle MakeResourceHandle( result.set_hash_code(type_index.hash_code()); result.set_maybe_type_name(type_index.name()); result.set_dtypes_and_shapes(dtypes_and_shapes); - result.set_allowed_devices(allowed_devices); return result; } @@ -67,39 +65,12 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, namespace internal { Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { - const string& current_device_name = ctx->device()->attributes().name(); - if (current_device_name == p.device()) { - return Status::OK(); - } - DeviceNameUtils::ParsedName parsed_current_device_name; - if (!DeviceNameUtils::ParseFullName(current_device_name, - &parsed_current_device_name)) { + if (ctx->device()->attributes().name() != p.device()) { return errors::InvalidArgument( - "Cannot parse device name in OpKernelContext: ", current_device_name); + "Trying to access resource ", p.name(), " located in device ", + p.device(), " from device ", ctx->device()->attributes().name()); } - - for (const string& device : p.allowed_devices()) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::InvalidArgument("Cannot parse allowed device name: ", - device); - } - if (DeviceNameUtils::IsCompleteSpecification(parsed, - parsed_current_device_name)) { - return Status::OK(); - } - } - string error_message = strings::StrCat("Trying to access resource ", p.name(), - " located in device ", p.device(), - " from device ", current_device_name); - if (!p.allowed_devices().empty()) { - absl::StrAppend(&error_message, " (allowed devices: "); - for (const string& device : p.allowed_devices()) { - absl::StrAppend(&error_message, device, ", "); - } - absl::StrAppend(&error_message, ") "); - } - return errors::InvalidArgument(error_message); + return Status::OK(); } } // end namespace internal diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 3a9b97c7831..b0e4eace16e 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -291,31 +291,27 @@ class ResourceMgr { ResourceHandle MakeResourceHandle( const string& container, const string& name, const DeviceBase& device, const TypeIndex& type_index, - const std::vector& dtypes_and_shapes = {}, - const std::vector& allowed_devices = {}) TF_MUST_USE_RESULT; + const std::vector& dtypes_and_shapes = {}) + TF_MUST_USE_RESULT; template ResourceHandle MakeResourceHandle( OpKernelContext* ctx, const string& container, const string& name, - const std::vector& dtypes_and_shapes = {}, - const std::vector& allowed_devices = {}) { - return MakeResourceHandle(container.empty() - ? ctx->resource_manager()->default_container() - : container, - name, *ctx->device(), MakeTypeIndex(), - dtypes_and_shapes, allowed_devices); + const std::vector& dtypes_and_shapes = {}) { + return MakeResourceHandle( + container.empty() ? ctx->resource_manager()->default_container() + : container, + name, *ctx->device(), MakeTypeIndex(), dtypes_and_shapes); } template ResourceHandle MakeResourceHandle( OpKernelConstruction* ctx, const string& container, const string& name, - const std::vector& dtypes_and_shapes = {}, - const std::vector& allowed_devices = {}) { - return MakeResourceHandle(container.empty() - ? ctx->resource_manager()->default_container() - : container, - name, *ctx->device(), MakeTypeIndex(), - dtypes_and_shapes, allowed_devices); + const std::vector& dtypes_and_shapes = {}) { + return MakeResourceHandle( + container.empty() ? ctx->resource_manager()->default_container() + : container, + name, *ctx->device(), MakeTypeIndex(), dtypes_and_shapes); } Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index a48024123a6..f524ff77c11 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -352,51 +352,4 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) { EXPECT_NE(LookupResource(&ctx, p, &lookup_r).ok(), true); } -TEST(ResourceHandleTest, AllowedDevices) { - const std::vector device_names = { - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:0/device:CPU:2", - "/job:worker/replica:1/task:3/device:CPU:5"}; - std::vector devices; - for (const string& name : device_names) { - devices.emplace_back(name); - } - - std::vector params(device_names.size()); - std::vector> resource_mgrs; - std::vector> ctxs; - for (int i = 0; i < device_names.size(); ++i) { - resource_mgrs.emplace_back( - absl::make_unique(/* default_container= */ "")); - params[i].resource_manager = resource_mgrs[i].get(); - params[i].device = &(devices[i]); - ctxs.emplace_back( - absl::make_unique(&(params[i]), /* num_outputs= */ 0)); - } - - const string partially_specified_name = - "/job:worker/replica:0/task:0/device:CPU:*"; - const string& fully_specified_name = device_names.at(2); - const std::vector allowed_devices = {partially_specified_name, - fully_specified_name}; - // Create a ResourceHandle on device 0. - ResourceHandle p = MakeResourceHandle( - ctxs[0].get(), "container", "name", - /* dtypes_and_shapes= */ {}, allowed_devices); - - std::vector resources; - for (const auto& ctx : ctxs) { - StubResource* r = new StubResource; - TF_EXPECT_OK(CreateResource(ctx.get(), p, r)); - resources.push_back(r); - } - - for (int i = 0; i < ctxs.size(); ++i) { - core::RefCountPtr lookup_r; - TF_EXPECT_OK(LookupResource(ctxs[i].get(), p, &lookup_r)); - EXPECT_EQ(lookup_r.get(), resources[i]); - TF_EXPECT_OK(DeleteResource(ctxs[i].get(), p)); - } -} - } // end namespace tensorflow diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 0fc1d53749f..b9c883c7e2f 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -222,8 +222,6 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype)); PartialTensorShape shape; OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape)); - OP_REQUIRES_OK(context, - context->GetAttr("allowed_devices", &allowed_devices_)); is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME; @@ -234,8 +232,7 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) { &resource_, attr)); resource_.scalar()() = MakeResourceHandle( context, container_, name_, - std::vector{dtype_and_shape_}, - allowed_devices_); + std::vector{dtype_and_shape_}); } } @@ -248,8 +245,7 @@ void VarHandleOp::Compute(OpKernelContext* ctx) { ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr)); handle.scalar()() = MakeResourceHandle( ctx, container_, name_, - std::vector{dtype_and_shape_}, - allowed_devices_); + std::vector{dtype_and_shape_}); ctx->set_output(0, handle); } else { ctx->set_output(0, resource_); diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h index 5935fa91d21..1bb70b537c1 100644 --- a/tensorflow/core/kernels/resource_variable_ops.h +++ b/tensorflow/core/kernels/resource_variable_ops.h @@ -36,10 +36,6 @@ class VarHandleOp : public OpKernel { Tensor resource_; DtypeAndPartialTensorShape dtype_and_shape_; - - // A set of devices containing the resource variable. Set when the output - // ResourceHandle represents a per-replica/partitioned resource variable. - std::vector allowed_devices_; }; class ReadVariableOp : public OpKernel { diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index b45e9dfb2bc..fb172fbcb10 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -30,7 +30,6 @@ from tensorflow.core.framework import tensor_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes @@ -1513,41 +1512,5 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertAllEqual(expected, result) -class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase): - - def setUp(self): - super(PerReplicaResourceHandleTest, self).setUp() - cpus = config.list_physical_devices("CPU") - # Set 2 virtual CPUs - config.set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - ]) - - @test_util.disable_tfrt("Multiple device support. b/154956430") - def testAllowedDevices(self): - device0 = "/job:localhost/replica:0/task:0/device:CPU:0" - device1 = "/job:localhost/replica:0/task:0/device:CPU:1" - value0 = 1 - value1 = 2 - with context.eager_mode(): - handle = resource_variable_ops.var_handle_op( - dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1]) - with ops.device(device0): - assign0 = resource_variable_ops.assign_variable_op(handle, value0) - with ops.device(device1): - assign1 = resource_variable_ops.assign_variable_op(handle, value1) - with ops.control_dependencies([assign0, assign1]): - with ops.device(device0): - read0 = resource_variable_ops.read_variable_op( - handle, dtype=dtypes.int32) - with ops.device(device1): - read1 = resource_variable_ops.read_variable_op( - handle, dtype=dtypes.int32) - - self.assertAllEqual(value0, read0) - self.assertAllEqual(value1, read1) - - if __name__ == "__main__": test.main()