[Cleanup] Remove allowed_devices of ResourceHandle since it's no longer used.

PiperOrigin-RevId: 317710941
Change-Id: Ib1920c5ee25d405290f852b725d693ee5ea09766
This commit is contained in:
Yujing Zhang 2020-06-22 12:29:26 -07:00 committed by TensorFlower Gardener
parent aac1dd5788
commit 38d95ad2d8
15 changed files with 63 additions and 259 deletions

View File

@ -28,8 +28,8 @@ END
attr { attr {
name: "allowed_devices" name: "allowed_devices"
description: <<END description: <<END
The allowed devices containing the resource variable. Set when the output DEPRECATED. The allowed devices containing the resource variable. Set when the
ResourceHandle represents a per-replica/partitioned resource variable. output ResourceHandle represents a per-replica/partitioned resource variable.
END END
} }
summary: "Creates a handle to a Variable resource." summary: "Creates a handle to a Variable resource."

View File

@ -432,13 +432,12 @@ Status GetOrCreateKernelAndDevice(
// looking it up in ResourceMgr, which is slow). So we just get // looking it up in ResourceMgr, which is slow). So we just get
// resource_dtypes_and_shapes for all DT_RESOURCE inputs. If // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
// resource_dtypes_and_shapes is not empty, take the first element. // resource_dtypes_and_shapes is not empty, take the first element.
TensorHandle::ResourceHandleInfo resource_handle_info; std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info)); TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes = &resource_dtypes_and_shapes));
&resource_handle_info.dtypes_and_shapes; if (!resource_dtypes_and_shapes.empty()) {
if (!resource_dtypes_and_shapes->empty()) {
const DtypeAndPartialTensorShape& dtype_and_shape = const DtypeAndPartialTensorShape& dtype_and_shape =
resource_dtypes_and_shapes->at(0); resource_dtypes_and_shapes.at(0);
input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape; input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
// Add _Arg index, dtype and shape to "cache_key". // Add _Arg index, dtype and shape to "cache_key".
@ -695,13 +694,8 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype)); TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
const AttrValue* shape; const AttrValue* shape;
TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape)); TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
TensorHandle::ResourceHandleInfo resource_handle_info = { retvals[0]->SetResourceHandleDtypeAndShape(
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}}; {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
// "allowed_devices" is set only when the output represents a
// per-replica/partitioned resource variable.
TryGetNodeAttr(attr_slice, "allowed_devices",
&resource_handle_info.allowed_devices);
retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
} }
return Status::OK(); return Status::OK();
} }
@ -985,18 +979,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// is a resource we must pin it to prevent different device selection. // is a resource we must pin it to prevent different device selection.
// TODO(iga): null device can mean "unspecified" or "CPU". Clean this up. // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
if (resource_device != op_device || op->Device() == kVariantDeviceNull) { if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
std::vector<string> allowed_devices;
TF_RETURN_IF_ERROR(
tensor_handle->GetResourceAllowedDevices(&allowed_devices));
if (!allowed_devices.empty()) {
// TODO(b/145922293): Support allowed_devices specified in wildcard
// patterns.
if (std::find(allowed_devices.begin(), allowed_devices.end(),
op->DeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
&resource_device));
}
}
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ") DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
<< "device of operation " << op->Name() << " to " << "device of operation " << op->Name() << " to "
<< resource_device->name() << " because input #" << i << resource_device->name() << " because input #" << i

View File

@ -145,13 +145,13 @@ Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
return Status::OK(); return Status::OK();
} }
void TensorHandle::SetResourceHandleInfo( void TensorHandle::SetResourceHandleDtypeAndShape(
ResourceHandleInfo&& resource_handle_info) { std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
resource_handle_info_ = std::move(resource_handle_info); handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
} }
Status TensorHandle::GetResourceHandleInfoImpl( Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::function<void()> set_resource_info) { std::vector<DtypeAndPartialTensorShape>* result) {
if (dtype != DT_RESOURCE) { if (dtype != DT_RESOURCE) {
return errors::InvalidArgument( return errors::InvalidArgument(
"TensorHandle::GetResourceDtypeAndShape should be called on tensor " "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
@ -160,7 +160,7 @@ Status TensorHandle::GetResourceHandleInfoImpl(
} }
if (Type() != LOCAL) { if (Type() != LOCAL) {
set_resource_info(); *result = handle_dtypes_and_shapes_;
return Status::OK(); return Status::OK();
} }
@ -170,32 +170,10 @@ Status TensorHandle::GetResourceHandleInfoImpl(
auto& data = absl::get<LocalTensorHandleData>(data_); auto& data = absl::get<LocalTensorHandleData>(data_);
TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo")); TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));
set_resource_info(); *result = handle_dtypes_and_shapes_;
return Status::OK(); return Status::OK();
} }
Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_.dtypes_and_shapes;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_.allowed_devices;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
int TensorHandle::NumPackedHandles() const { int TensorHandle::NumPackedHandles() const {
if (Type() != PACKED) { if (Type() != PACKED) {
return 0; return 0;
@ -270,9 +248,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
resource_remote_device_incarnation_( resource_remote_device_incarnation_(
GetRemoteDeviceIncarnation(resource_device_)), GetRemoteDeviceIncarnation(resource_device_)),
ctx_(ctx), ctx_(ctx),
resource_handle_info_( handle_dtypes_and_shapes_(
{t.flat<class ResourceHandle>()(0).dtypes_and_shapes(), t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
t.flat<class ResourceHandle>()(0).allowed_devices()}),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) { data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_) << " device: " << VariantDeviceDebugString(device_)
@ -327,10 +304,10 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
return errors::InvalidArgument("Handles should not be empty."); return errors::InvalidArgument("Handles should not be empty.");
} }
ResourceHandleInfo resource_handle_info; std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
if (dtype == DT_RESOURCE) { if (dtype == DT_RESOURCE) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
handles.at(0)->GetResourceHandleInfo(&resource_handle_info)); handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
} }
std::vector<string> devices; std::vector<string> devices;
for (auto* handle : handles) { for (auto* handle : handles) {
@ -348,7 +325,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
&composite_device)); &composite_device));
*packed_handle = *packed_handle =
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx); new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info)); (*packed_handle)
->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
return Status::OK(); return Status::OK();
} }
@ -898,8 +876,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
auto& resource_handle = t.flat<class ResourceHandle>()(0); auto& resource_handle = t.flat<class ResourceHandle>()(0);
resource_handle_info_ = {resource_handle.dtypes_and_shapes(), handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
resource_handle.allowed_devices()};
} }
auto& data = absl::get<LocalTensorHandleData>(data_); auto& data = absl::get<LocalTensorHandleData>(data_);
return data.SetTensor(std::move(t)); return data.SetTensor(std::move(t));

View File

@ -226,19 +226,13 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
string DebugString() const; string DebugString() const;
struct ResourceHandleInfo { void SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes; std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
std::vector<string> allowed_devices;
};
void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info);
// If this TensorHandle is 1) a local tensor, and 2) a resource handle, // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
// return data types, shapes and allowed devices of the underlying resource. // return data types and shapes of the underlying resource.
Status GetResourceHandleInfo(ResourceHandleInfo* result);
Status GetResourceHandleDtypesAndShapes( Status GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result); std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
// Returns the number of packed handles. 0 if the handle type is not PACKED. // Returns the number of packed handles. 0 if the handle type is not PACKED.
int NumPackedHandles() const; int NumPackedHandles() const;
@ -261,8 +255,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
// with a ready version of the tensor handle data. // with a ready version of the tensor handle data.
bool IsReady() const; bool IsReady() const;
Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
VariantDevice const device_; VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to // Device in which the op producing this tensor was executed. Equals to
@ -308,9 +300,9 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
Status is_poisoned_; Status is_poisoned_;
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types, shapes and allowed // refers to a remote resource handle, we store data types and shapes for
// devices for the underlying resource. // the underlying resource.
ResourceHandleInfo resource_handle_info_; std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// A handle data which refers to multiple TensorHandles of the same dtype and // A handle data which refers to multiple TensorHandles of the same dtype and
// shape. // shape.

View File

@ -150,13 +150,13 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
Device* d0 = ListDevices().at(0); Device* d0 = ListDevices().at(0);
TensorHandle* h0 = TensorHandle* h0 =
TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context()); TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context());
h0->SetResourceHandleInfo({{dtype_and_shape}, {}}); h0->SetResourceHandleDtypeAndShape({dtype_and_shape});
handles.push_back(h0); handles.push_back(h0);
Tensor t1(dtype, shape); Tensor t1(dtype, shape);
Device* d1 = ListDevices().at(1); Device* d1 = ListDevices().at(1);
TensorHandle* h1 = TensorHandle* h1 =
TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context()); TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context());
h1->SetResourceHandleInfo({{dtype_and_shape}, {}}); h1->SetResourceHandleDtypeAndShape({dtype_and_shape});
handles.push_back(h1); handles.push_back(h1);
// Create 2 remote TensorHandles (not ready). // Create 2 remote TensorHandles (not ready).
@ -185,13 +185,12 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
TensorShape packed_shape; TensorShape packed_shape;
TF_ASSERT_OK(packed_handle->Shape(&packed_shape)); TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
EXPECT_EQ(packed_shape, shape); EXPECT_EQ(packed_shape, shape);
TensorHandle::ResourceHandleInfo resource_handle_info; std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info)); TF_ASSERT_OK(
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1); packed_handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT); EXPECT_EQ(dtypes_and_shapes.size(), 1);
EXPECT_EQ( EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT);
resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true);
true);
CompositeDevice* device = reinterpret_cast<CompositeDevice*>( CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
absl::get<Device*>(packed_handle->device())); absl::get<Device*>(packed_handle->device()));

View File

@ -167,24 +167,22 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
parent_->FindDeviceFromName(device_name.c_str(), &device)); parent_->FindDeviceFromName(device_name.c_str(), &device));
*out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(), *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
in.dtype(), device, parent_); in.dtype(), device, parent_);
TensorHandle::ResourceHandleInfo resource_handle_info; std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in), if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
dtypes_and_shapes) &dtypes_and_shapes)
.ok()) { .ok()) {
for (const auto& dtype_and_shape_proto : for (const auto& dtype_and_shape_proto :
in.resource_dtypes_and_shapes()) { in.resource_dtypes_and_shapes()) {
dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{ dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(), dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())}); TensorShape(dtype_and_shape_proto.shape())});
} }
mutex_lock l(mirrored_resource_shape_mu_); mutex_lock l(mirrored_resource_shape_mu_);
mirrored_resource_shape_map_.emplace( mirrored_resource_shape_map_.emplace(
RemoteTensorHandleInternal(in.op_id(), in.output_num()), RemoteTensorHandleInternal(in.op_id(), in.output_num()),
*dtypes_and_shapes); dtypes_and_shapes);
} }
(*out)->SetResourceHandleInfo(std::move(resource_handle_info)); (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
} }
return Status::OK(); return Status::OK();

View File

@ -38,9 +38,6 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const {
dtype_and_shape->set_dtype(dtype_and_shape_pair.dtype); dtype_and_shape->set_dtype(dtype_and_shape_pair.dtype);
dtype_and_shape_pair.shape.AsProto(dtype_and_shape->mutable_shape()); dtype_and_shape_pair.shape.AsProto(dtype_and_shape->mutable_shape());
} }
for (const string& device : allowed_devices_) {
*proto->add_allowed_devices() = device;
}
} }
void ResourceHandle::FromProto(const ResourceHandleProto& proto) { void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
@ -56,9 +53,6 @@ void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape}); dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape});
} }
dtypes_and_shapes_ = std::move(dtypes_and_shapes); dtypes_and_shapes_ = std::move(dtypes_and_shapes);
for (const string& device : proto.allowed_devices()) {
allowed_devices_.push_back(device);
}
} }
string ResourceHandle::SerializeAsString() const { string ResourceHandle::SerializeAsString() const {

View File

@ -39,14 +39,8 @@ class ResourceHandle {
// Unique name for the device containing the resource. // Unique name for the device containing the resource.
const std::string& device() const { return device_; } const std::string& device() const { return device_; }
// Names of the devices containing the resource.
const std::vector<string>& allowed_devices() const {
return allowed_devices_;
}
void set_device(const std::string& device) { device_ = device; } void set_device(const std::string& device) { device_ = device; }
void set_allowed_devices(const std::vector<string>& devices) {
allowed_devices_ = devices;
}
// Container in which this resource is placed. // Container in which this resource is placed.
const std::string& container() const { return container_; } const std::string& container() const { return container_; }
@ -93,12 +87,7 @@ class ResourceHandle {
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
public: public:
// The default device containing the resource, where the ResourceHandle is
// initially created.
std::string device_; std::string device_;
// A set of devices containing the resource. If empty, the resource only
// exists on device_. Can be represented in wildcard patterns.
std::vector<string> allowed_devices_;
std::string container_; std::string container_;
std::string name_; std::string name_;
uint64 hash_code_ = 0; uint64 hash_code_ = 0;

View File

@ -41,7 +41,5 @@ message ResourceHandleProto {
// Data types and shapes for the underlying resource. // Data types and shapes for the underlying resource.
repeated DtypeAndShape dtypes_and_shapes = 6; repeated DtypeAndShape dtypes_and_shapes = 6;
// A set of devices containing the resource. If empty, the resource only reserved 7;
// exists on `device`.
repeated string allowed_devices = 7;
} }

View File

@ -36,8 +36,7 @@ static std::atomic<int64> current_id_;
ResourceHandle MakeResourceHandle( ResourceHandle MakeResourceHandle(
const string& container, const string& name, const DeviceBase& device, const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index, const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes, const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
const std::vector<string>& allowed_devices) {
ResourceHandle result; ResourceHandle result;
result.set_device(device.name()); result.set_device(device.name());
result.set_container(container); result.set_container(container);
@ -49,7 +48,6 @@ ResourceHandle MakeResourceHandle(
result.set_hash_code(type_index.hash_code()); result.set_hash_code(type_index.hash_code());
result.set_maybe_type_name(type_index.name()); result.set_maybe_type_name(type_index.name());
result.set_dtypes_and_shapes(dtypes_and_shapes); result.set_dtypes_and_shapes(dtypes_and_shapes);
result.set_allowed_devices(allowed_devices);
return result; return result;
} }
@ -67,40 +65,13 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
namespace internal { namespace internal {
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
const string& current_device_name = ctx->device()->attributes().name(); if (ctx->device()->attributes().name() != p.device()) {
if (current_device_name == p.device()) {
return Status::OK();
}
DeviceNameUtils::ParsedName parsed_current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device_name,
&parsed_current_device_name)) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Cannot parse device name in OpKernelContext: ", current_device_name); "Trying to access resource ", p.name(), " located in device ",
p.device(), " from device ", ctx->device()->attributes().name());
} }
for (const string& device : p.allowed_devices()) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::InvalidArgument("Cannot parse allowed device name: ",
device);
}
if (DeviceNameUtils::IsCompleteSpecification(parsed,
parsed_current_device_name)) {
return Status::OK(); return Status::OK();
} }
}
string error_message = strings::StrCat("Trying to access resource ", p.name(),
" located in device ", p.device(),
" from device ", current_device_name);
if (!p.allowed_devices().empty()) {
absl::StrAppend(&error_message, " (allowed devices: ");
for (const string& device : p.allowed_devices()) {
absl::StrAppend(&error_message, device, ", ");
}
absl::StrAppend(&error_message, ") ");
}
return errors::InvalidArgument(error_message);
}
} // end namespace internal } // end namespace internal

View File

@ -291,31 +291,27 @@ class ResourceMgr {
ResourceHandle MakeResourceHandle( ResourceHandle MakeResourceHandle(
const string& container, const string& name, const DeviceBase& device, const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index, const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
const std::vector<string>& allowed_devices = {}) TF_MUST_USE_RESULT; TF_MUST_USE_RESULT;
template <typename T> template <typename T>
ResourceHandle MakeResourceHandle( ResourceHandle MakeResourceHandle(
OpKernelContext* ctx, const string& container, const string& name, OpKernelContext* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
const std::vector<string>& allowed_devices = {}) { return MakeResourceHandle(
return MakeResourceHandle(container.empty() container.empty() ? ctx->resource_manager()->default_container()
? ctx->resource_manager()->default_container()
: container, : container,
name, *ctx->device(), MakeTypeIndex<T>(), name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
dtypes_and_shapes, allowed_devices);
} }
template <typename T> template <typename T>
ResourceHandle MakeResourceHandle( ResourceHandle MakeResourceHandle(
OpKernelConstruction* ctx, const string& container, const string& name, OpKernelConstruction* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
const std::vector<string>& allowed_devices = {}) { return MakeResourceHandle(
return MakeResourceHandle(container.empty() container.empty() ? ctx->resource_manager()->default_container()
? ctx->resource_manager()->default_container()
: container, : container,
name, *ctx->device(), MakeTypeIndex<T>(), name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
dtypes_and_shapes, allowed_devices);
} }
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,

View File

@ -352,51 +352,4 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true); EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
} }
TEST(ResourceHandleTest, AllowedDevices) {
const std::vector<string> device_names = {
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:0/device:CPU:2",
"/job:worker/replica:1/task:3/device:CPU:5"};
std::vector<StubDevice> devices;
for (const string& name : device_names) {
devices.emplace_back(name);
}
std::vector<OpKernelContext::Params> params(device_names.size());
std::vector<std::unique_ptr<ResourceMgr>> resource_mgrs;
std::vector<std::unique_ptr<OpKernelContext>> ctxs;
for (int i = 0; i < device_names.size(); ++i) {
resource_mgrs.emplace_back(
absl::make_unique<ResourceMgr>(/* default_container= */ ""));
params[i].resource_manager = resource_mgrs[i].get();
params[i].device = &(devices[i]);
ctxs.emplace_back(
absl::make_unique<OpKernelContext>(&(params[i]), /* num_outputs= */ 0));
}
const string partially_specified_name =
"/job:worker/replica:0/task:0/device:CPU:*";
const string& fully_specified_name = device_names.at(2);
const std::vector<string> allowed_devices = {partially_specified_name,
fully_specified_name};
// Create a ResourceHandle on device 0.
ResourceHandle p = MakeResourceHandle<StubResource>(
ctxs[0].get(), "container", "name",
/* dtypes_and_shapes= */ {}, allowed_devices);
std::vector<StubResource*> resources;
for (const auto& ctx : ctxs) {
StubResource* r = new StubResource;
TF_EXPECT_OK(CreateResource(ctx.get(), p, r));
resources.push_back(r);
}
for (int i = 0; i < ctxs.size(); ++i) {
core::RefCountPtr<StubResource> lookup_r;
TF_EXPECT_OK(LookupResource<StubResource>(ctxs[i].get(), p, &lookup_r));
EXPECT_EQ(lookup_r.get(), resources[i]);
TF_EXPECT_OK(DeleteResource(ctxs[i].get(), p));
}
}
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -222,8 +222,6 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
PartialTensorShape shape; PartialTensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape)); OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
OP_REQUIRES_OK(context,
context->GetAttr("allowed_devices", &allowed_devices_));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME; is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
@ -234,8 +232,7 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
&resource_, attr)); &resource_, attr));
resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>( resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
context, container_, name_, context, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_}, std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
allowed_devices_);
} }
} }
@ -248,8 +245,7 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr)); ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>( handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
ctx, container_, name_, ctx, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_}, std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
allowed_devices_);
ctx->set_output(0, handle); ctx->set_output(0, handle);
} else { } else {
ctx->set_output(0, resource_); ctx->set_output(0, resource_);

View File

@ -36,10 +36,6 @@ class VarHandleOp : public OpKernel {
Tensor resource_; Tensor resource_;
DtypeAndPartialTensorShape dtype_and_shape_; DtypeAndPartialTensorShape dtype_and_shape_;
// A set of devices containing the resource variable. Set when the output
// ResourceHandle represents a per-replica/partitioned resource variable.
std::vector<string> allowed_devices_;
}; };
class ReadVariableOp : public OpKernel { class ReadVariableOp : public OpKernel {

View File

@ -30,7 +30,6 @@ from tensorflow.core.framework import tensor_pb2
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -1513,41 +1512,5 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertAllEqual(expected, result) self.assertAllEqual(expected, result)
class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
def setUp(self):
super(PerReplicaResourceHandleTest, self).setUp()
cpus = config.list_physical_devices("CPU")
# Set 2 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
])
@test_util.disable_tfrt("Multiple device support. b/154956430")
def testAllowedDevices(self):
device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
value0 = 1
value1 = 2
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
with ops.device(device0):
assign0 = resource_variable_ops.assign_variable_op(handle, value0)
with ops.device(device1):
assign1 = resource_variable_ops.assign_variable_op(handle, value1)
with ops.control_dependencies([assign0, assign1]):
with ops.device(device0):
read0 = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
with ops.device(device1):
read1 = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
self.assertAllEqual(value0, read0)
self.assertAllEqual(value1, read1)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()