[Cleanup] Remove allowed_devices of ResourceHandle since it's no longer used.
PiperOrigin-RevId: 317710941
Change-Id: Ib1920c5ee25d405290f852b725d693ee5ea09766
parent aac1dd5788
commit 38d95ad2d8
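The hunks below delete the allowed_devices plumbing end to end: the ResourceHandleProto field, the ResourceHandle accessors, TensorHandle's ResourceHandleInfo struct, the MakeResourceHandle parameter, and the VarHandleOp member. Only the op attr itself survives, re-documented as DEPRECATED in the api_def, presumably so GraphDefs that still set the attr keep loading. A minimal before/after sketch of a MakeResourceHandle call site (the ctx, dtype_and_shape, container, and name values are made up for illustration):

    // Before this change: a handle could carry a device allow-list,
    // including wildcard entries.
    ResourceHandle h_old = MakeResourceHandle<Var>(
        ctx, "container", "name",
        /*dtypes_and_shapes=*/{dtype_and_shape},
        /*allowed_devices=*/{"/job:worker/replica:0/task:0/device:CPU:*"});

    // After this change: the overload without allowed_devices is the only
    // one left; the handle is tied to the device that created it.
    ResourceHandle h_new = MakeResourceHandle<Var>(
        ctx, "container", "name",
        /*dtypes_and_shapes=*/{dtype_and_shape});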
@@ -28,8 +28,8 @@ END
 attr {
   name: "allowed_devices"
   description: <<END
-The allowed devices containing the resource variable. Set when the output
-ResourceHandle represents a per-replica/partitioned resource variable.
+DEPRECATED. The allowed devices containing the resource variable. Set when the
+output ResourceHandle represents a per-replica/partitioned resource variable.
 END
 }
 summary: "Creates a handle to a Variable resource."

@@ -432,13 +432,12 @@ Status GetOrCreateKernelAndDevice(
       // looking it up in ResourceMgr, which is slow). So we just get
       // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
       // resource_dtypes_and_shapes is not empty, take the first element.
-      TensorHandle::ResourceHandleInfo resource_handle_info;
-      TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
-      std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
-          &resource_handle_info.dtypes_and_shapes;
-      if (!resource_dtypes_and_shapes->empty()) {
+      std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
+      TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
+          &resource_dtypes_and_shapes));
+      if (!resource_dtypes_and_shapes.empty()) {
         const DtypeAndPartialTensorShape& dtype_and_shape =
-            resource_dtypes_and_shapes->at(0);
+            resource_dtypes_and_shapes.at(0);
         input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;

         // Add _Arg index, dtype and shape to "cache_key".

@@ -695,13 +694,8 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
     const AttrValue* shape;
     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
-    TensorHandle::ResourceHandleInfo resource_handle_info = {
-        {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
-    // "allowed_devices" is set only when the output represents a
-    // per-replica/partitioned resource variable.
-    TryGetNodeAttr(attr_slice, "allowed_devices",
-                   &resource_handle_info.allowed_devices);
-    retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
+    retvals[0]->SetResourceHandleDtypeAndShape(
+        {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
   }
   return Status::OK();
 }

@@ -985,18 +979,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
     // is a resource we must pin it to prevent different device selection.
     // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
     if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
-      std::vector<string> allowed_devices;
-      TF_RETURN_IF_ERROR(
-          tensor_handle->GetResourceAllowedDevices(&allowed_devices));
-      if (!allowed_devices.empty()) {
-        // TODO(b/145922293): Support allowed_devices specified in wildcard
-        // patterns.
-        if (std::find(allowed_devices.begin(), allowed_devices.end(),
-                      op->DeviceName()) != allowed_devices.end()) {
-          TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
-                                                    &resource_device));
-        }
-      }
       DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
               << "device of operation " << op->Name() << " to "
               << resource_device->name() << " because input #" << i

@@ -145,13 +145,13 @@ Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
   return Status::OK();
 }

-void TensorHandle::SetResourceHandleInfo(
-    ResourceHandleInfo&& resource_handle_info) {
-  resource_handle_info_ = std::move(resource_handle_info);
+void TensorHandle::SetResourceHandleDtypeAndShape(
+    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
+  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
 }

-Status TensorHandle::GetResourceHandleInfoImpl(
-    std::function<void()> set_resource_info) {
+Status TensorHandle::GetResourceHandleDtypesAndShapes(
+    std::vector<DtypeAndPartialTensorShape>* result) {
   if (dtype != DT_RESOURCE) {
     return errors::InvalidArgument(
         "TensorHandle::GetResourceDtypeAndShape should be called on tensor "

@@ -160,7 +160,7 @@ Status TensorHandle::GetResourceHandleInfoImpl(
   }

   if (Type() != LOCAL) {
-    set_resource_info();
+    *result = handle_dtypes_and_shapes_;
     return Status::OK();
   }

@@ -170,32 +170,10 @@ Status TensorHandle::GetResourceHandleInfoImpl(
   auto& data = absl::get<LocalTensorHandleData>(data_);
   TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));

-  set_resource_info();
+  *result = handle_dtypes_and_shapes_;
   return Status::OK();
 }

-Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
-  auto get_resource_info = [result, this]() {
-    *result = resource_handle_info_;
-  };
-  return GetResourceHandleInfoImpl(get_resource_info);
-}
-
-Status TensorHandle::GetResourceHandleDtypesAndShapes(
-    std::vector<DtypeAndPartialTensorShape>* result) {
-  auto get_resource_info = [result, this]() {
-    *result = resource_handle_info_.dtypes_and_shapes;
-  };
-  return GetResourceHandleInfoImpl(get_resource_info);
-}
-
-Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
-  auto get_resource_info = [result, this]() {
-    *result = resource_handle_info_.allowed_devices;
-  };
-  return GetResourceHandleInfoImpl(get_resource_info);
-}
-
 int TensorHandle::NumPackedHandles() const {
   if (Type() != PACKED) {
     return 0;

@@ -270,9 +248,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
       resource_remote_device_incarnation_(
           GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
-      resource_handle_info_(
-          {t.flat<class ResourceHandle>()(0).dtypes_and_shapes(),
-           t.flat<class ResourceHandle>()(0).allowed_devices()}),
+      handle_dtypes_and_shapes_(
+          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_)

@@ -327,10 +304,10 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
     return errors::InvalidArgument("Handles should not be empty.");
   }

-  ResourceHandleInfo resource_handle_info;
+  std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
   if (dtype == DT_RESOURCE) {
     TF_RETURN_IF_ERROR(
-        handles.at(0)->GetResourceHandleInfo(&resource_handle_info));
+        handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
   }
   std::vector<string> devices;
   for (auto* handle : handles) {

@@ -348,7 +325,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                      &composite_device));
   *packed_handle =
       new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
-  (*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
+  (*packed_handle)
+      ->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
   return Status::OK();
 }

@@ -898,8 +876,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {

   if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
     auto& resource_handle = t.flat<class ResourceHandle>()(0);
-    resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
-                             resource_handle.allowed_devices()};
+    handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
   }
   auto& data = absl::get<LocalTensorHandleData>(data_);
   return data.SetTensor(std::move(t));

@@ -226,19 +226,13 @@ class TensorHandle : public ImmediateExecutionTensorHandle,

   string DebugString() const;

-  struct ResourceHandleInfo {
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
-    std::vector<string> allowed_devices;
-  };
-
-  void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info);
+  void SetResourceHandleDtypeAndShape(
+      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
-  // return data types, shapes and allowed devices of the underlying resource.
-  Status GetResourceHandleInfo(ResourceHandleInfo* result);
+  // return data types and shapes of the underlying resource.
   Status GetResourceHandleDtypesAndShapes(
       std::vector<DtypeAndPartialTensorShape>* result);
-  Status GetResourceAllowedDevices(std::vector<string>* result);

   // Returns the number of packed handles. 0 if the handle type is not PACKED.
   int NumPackedHandles() const;

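With the struct gone, resource metadata on a TensorHandle is just the dtype/shape vector. A usage sketch against the new header (the handle pointer is hypothetical; DT_FLOAT and the {2, 2} shape echo the test further down):

    // Producer side: attach metadata when the tensor is a resource handle.
    handle->SetResourceHandleDtypeAndShape(
        {DtypeAndPartialTensorShape{DT_FLOAT, PartialTensorShape({2, 2})}});

    // Consumer side: read it back. Per tensor_handle.cc above, this returns
    // InvalidArgument when the handle's dtype is not DT_RESOURCE.
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    TF_RETURN_IF_ERROR(
        handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));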
@@ -261,8 +255,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
   // with a ready version of the tensor handle data.
   bool IsReady() const;

-  Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
-
   VariantDevice const device_;

   // Device in which the op producing this tensor was executed. Equals to

@@ -308,9 +300,9 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
   Status is_poisoned_;

   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
-  // refers to a remote resource handle, we store data types, shapes and allowed
-  // devices for the underlying resource.
-  ResourceHandleInfo resource_handle_info_;
+  // refers to a remote resource handle, we store data types and shapes for
+  // the underlying resource.
+  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

   // A handle data which refers to multiple TensorHandles of the same dtype and
   // shape.

@@ -150,13 +150,13 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
   Device* d0 = ListDevices().at(0);
   TensorHandle* h0 =
       TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context());
-  h0->SetResourceHandleInfo({{dtype_and_shape}, {}});
+  h0->SetResourceHandleDtypeAndShape({dtype_and_shape});
   handles.push_back(h0);
   Tensor t1(dtype, shape);
   Device* d1 = ListDevices().at(1);
   TensorHandle* h1 =
       TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context());
-  h1->SetResourceHandleInfo({{dtype_and_shape}, {}});
+  h1->SetResourceHandleDtypeAndShape({dtype_and_shape});
   handles.push_back(h1);

   // Create 2 remote TensorHandles (not ready).

@@ -185,13 +185,12 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
   TensorShape packed_shape;
   TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
   EXPECT_EQ(packed_shape, shape);
-  TensorHandle::ResourceHandleInfo resource_handle_info;
-  TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info));
-  EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1);
-  EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT);
-  EXPECT_EQ(
-      resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}),
-      true);
+  std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
+  TF_ASSERT_OK(
+      packed_handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
+  EXPECT_EQ(dtypes_and_shapes.size(), 1);
+  EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT);
+  EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true);

   CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
       absl::get<Device*>(packed_handle->device()));

@@ -167,24 +167,22 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
         parent_->FindDeviceFromName(device_name.c_str(), &device));
     *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                 in.dtype(), device, parent_);
-    TensorHandle::ResourceHandleInfo resource_handle_info;
-    std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
-        &resource_handle_info.dtypes_and_shapes;
+    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
-                                  dtypes_and_shapes)
+                                  &dtypes_and_shapes)
             .ok()) {
       for (const auto& dtype_and_shape_proto :
            in.resource_dtypes_and_shapes()) {
-        dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
+        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
             dtype_and_shape_proto.dtype(),
             TensorShape(dtype_and_shape_proto.shape())});
       }
       mutex_lock l(mirrored_resource_shape_mu_);
       mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
-          *dtypes_and_shapes);
+          dtypes_and_shapes);
     }
-    (*out)->SetResourceHandleInfo(std::move(resource_handle_info));
+    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
   }

   return Status::OK();

@@ -38,9 +38,6 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const {
     dtype_and_shape->set_dtype(dtype_and_shape_pair.dtype);
     dtype_and_shape_pair.shape.AsProto(dtype_and_shape->mutable_shape());
   }
-  for (const string& device : allowed_devices_) {
-    *proto->add_allowed_devices() = device;
-  }
 }

 void ResourceHandle::FromProto(const ResourceHandleProto& proto) {

|
@ -56,9 +53,6 @@ void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
|
|||
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape});
|
||||
}
|
||||
dtypes_and_shapes_ = std::move(dtypes_and_shapes);
|
||||
for (const string& device : proto.allowed_devices()) {
|
||||
allowed_devices_.push_back(device);
|
||||
}
|
||||
}
|
||||
|
||||
string ResourceHandle::SerializeAsString() const {
|
||||
|
|
|
@@ -39,14 +39,8 @@ class ResourceHandle {

   // Unique name for the device containing the resource.
   const std::string& device() const { return device_; }
-  // Names of the devices containing the resource.
-  const std::vector<string>& allowed_devices() const {
-    return allowed_devices_;
-  }
+
   void set_device(const std::string& device) { device_ = device; }
-  void set_allowed_devices(const std::vector<string>& devices) {
-    allowed_devices_ = devices;
-  }

   // Container in which this resource is placed.
   const std::string& container() const { return container_; }

@@ -93,12 +87,7 @@ class ResourceHandle {
       "cd2c89b7-88b7-44c8-ad83-06c2a9158347";

  public:
-  // The default device containing the resource, where the ResourceHandle is
-  // initially created.
   std::string device_;
-  // A set of devices containing the resource. If empty, the resource only
-  // exists on device_. Can be represented in wildcard patterns.
-  std::vector<string> allowed_devices_;
   std::string container_;
   std::string name_;
   uint64 hash_code_ = 0;

@@ -41,7 +41,5 @@ message ResourceHandleProto {
   // Data types and shapes for the underlying resource.
   repeated DtypeAndShape dtypes_and_shapes = 6;

-  // A set of devices containing the resource. If empty, the resource only
-  // exists on `device`.
-  repeated string allowed_devices = 7;
+  reserved 7;
 }

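Deleting a proto field is wire-safe only if its tag number is never recycled, since old serialized ResourceHandleProtos may still carry field 7; reserved 7; makes protoc reject any future reuse of that number. A sketch of the convention on a made-up message (reserving the retired name as well is optional but common):

    message Example {
      string device = 1;
      // Field 7 once held allowed_devices; keep the number (and, optionally,
      // the name) off-limits so stale serialized data cannot be misread.
      reserved 7;
      reserved "allowed_devices";
    }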
@@ -36,8 +36,7 @@ static std::atomic<int64> current_id_;
 ResourceHandle MakeResourceHandle(
     const string& container, const string& name, const DeviceBase& device,
     const TypeIndex& type_index,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes,
-    const std::vector<string>& allowed_devices) {
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
   ResourceHandle result;
   result.set_device(device.name());
   result.set_container(container);

@@ -49,7 +48,6 @@ ResourceHandle MakeResourceHandle(
   result.set_hash_code(type_index.hash_code());
   result.set_maybe_type_name(type_index.name());
   result.set_dtypes_and_shapes(dtypes_and_shapes);
-  result.set_allowed_devices(allowed_devices);
   return result;
 }

@@ -67,39 +65,12 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
 namespace internal {

 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
-  const string& current_device_name = ctx->device()->attributes().name();
-  if (current_device_name == p.device()) {
-    return Status::OK();
-  }
-  DeviceNameUtils::ParsedName parsed_current_device_name;
-  if (!DeviceNameUtils::ParseFullName(current_device_name,
-                                      &parsed_current_device_name)) {
+  if (ctx->device()->attributes().name() != p.device()) {
     return errors::InvalidArgument(
-        "Cannot parse device name in OpKernelContext: ", current_device_name);
+        "Trying to access resource ", p.name(), " located in device ",
+        p.device(), " from device ", ctx->device()->attributes().name());
   }
-
-  for (const string& device : p.allowed_devices()) {
-    DeviceNameUtils::ParsedName parsed;
-    if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
-      return errors::InvalidArgument("Cannot parse allowed device name: ",
-                                     device);
-    }
-    if (DeviceNameUtils::IsCompleteSpecification(parsed,
-                                                 parsed_current_device_name)) {
-      return Status::OK();
-    }
-  }
-  string error_message = strings::StrCat("Trying to access resource ", p.name(),
-                                         " located in device ", p.device(),
-                                         " from device ", current_device_name);
-  if (!p.allowed_devices().empty()) {
-    absl::StrAppend(&error_message, " (allowed devices: ");
-    for (const string& device : p.allowed_devices()) {
-      absl::StrAppend(&error_message, device, ", ");
-    }
-    absl::StrAppend(&error_message, ") ");
-  }
-  return errors::InvalidArgument(error_message);
+  return Status::OK();
 }

 }  // end namespace internal

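The rewritten ValidateDevice drops DeviceNameUtils parsing and wildcard matching entirely: access is granted only when the context's device name equals the handle's creating device, so a handle that used to match via an allowed_devices wildcard now fails with InvalidArgument from any other device. A standalone sketch of the resulting rule, outside TensorFlow (function and parameter names are hypothetical):

    #include <string>

    // Mirrors the simplified check above: no ParseFullName, no
    // IsCompleteSpecification; fully specified names must match exactly.
    bool CanAccessResource(const std::string& context_device_name,
                           const std::string& handle_device_name) {
      return context_device_name == handle_device_name;
    }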
@@ -291,31 +291,27 @@ class ResourceMgr {
 ResourceHandle MakeResourceHandle(
     const string& container, const string& name, const DeviceBase& device,
     const TypeIndex& type_index,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
-    const std::vector<string>& allowed_devices = {}) TF_MUST_USE_RESULT;
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
+    TF_MUST_USE_RESULT;

 template <typename T>
 ResourceHandle MakeResourceHandle(
     OpKernelContext* ctx, const string& container, const string& name,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
-    const std::vector<string>& allowed_devices = {}) {
-  return MakeResourceHandle(container.empty()
-                                ? ctx->resource_manager()->default_container()
-                                : container,
-                            name, *ctx->device(), MakeTypeIndex<T>(),
-                            dtypes_and_shapes, allowed_devices);
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
+  return MakeResourceHandle(
+      container.empty() ? ctx->resource_manager()->default_container()
+                        : container,
+      name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
 }

 template <typename T>
 ResourceHandle MakeResourceHandle(
     OpKernelConstruction* ctx, const string& container, const string& name,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
-    const std::vector<string>& allowed_devices = {}) {
-  return MakeResourceHandle(container.empty()
-                                ? ctx->resource_manager()->default_container()
-                                : container,
-                            name, *ctx->device(), MakeTypeIndex<T>(),
-                            dtypes_and_shapes, allowed_devices);
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
+  return MakeResourceHandle(
+      container.empty() ? ctx->resource_manager()->default_container()
+                        : container,
+      name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
 }

 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,

@@ -352,51 +352,4 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
   EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
 }

-TEST(ResourceHandleTest, AllowedDevices) {
-  const std::vector<string> device_names = {
-      "/job:worker/replica:0/task:0/device:CPU:0",
-      "/job:worker/replica:0/task:0/device:CPU:2",
-      "/job:worker/replica:1/task:3/device:CPU:5"};
-  std::vector<StubDevice> devices;
-  for (const string& name : device_names) {
-    devices.emplace_back(name);
-  }
-
-  std::vector<OpKernelContext::Params> params(device_names.size());
-  std::vector<std::unique_ptr<ResourceMgr>> resource_mgrs;
-  std::vector<std::unique_ptr<OpKernelContext>> ctxs;
-  for (int i = 0; i < device_names.size(); ++i) {
-    resource_mgrs.emplace_back(
-        absl::make_unique<ResourceMgr>(/* default_container= */ ""));
-    params[i].resource_manager = resource_mgrs[i].get();
-    params[i].device = &(devices[i]);
-    ctxs.emplace_back(
-        absl::make_unique<OpKernelContext>(&(params[i]), /* num_outputs= */ 0));
-  }
-
-  const string partially_specified_name =
-      "/job:worker/replica:0/task:0/device:CPU:*";
-  const string& fully_specified_name = device_names.at(2);
-  const std::vector<string> allowed_devices = {partially_specified_name,
-                                               fully_specified_name};
-  // Create a ResourceHandle on device 0.
-  ResourceHandle p = MakeResourceHandle<StubResource>(
-      ctxs[0].get(), "container", "name",
-      /* dtypes_and_shapes= */ {}, allowed_devices);
-
-  std::vector<StubResource*> resources;
-  for (const auto& ctx : ctxs) {
-    StubResource* r = new StubResource;
-    TF_EXPECT_OK(CreateResource(ctx.get(), p, r));
-    resources.push_back(r);
-  }
-
-  for (int i = 0; i < ctxs.size(); ++i) {
-    core::RefCountPtr<StubResource> lookup_r;
-    TF_EXPECT_OK(LookupResource<StubResource>(ctxs[i].get(), p, &lookup_r));
-    EXPECT_EQ(lookup_r.get(), resources[i]);
-    TF_EXPECT_OK(DeleteResource(ctxs[i].get(), p));
-  }
-}
-
 }  // end namespace tensorflow

@@ -222,8 +222,6 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
   OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
   PartialTensorShape shape;
   OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
-  OP_REQUIRES_OK(context,
-                 context->GetAttr("allowed_devices", &allowed_devices_));

   is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;

@@ -234,8 +232,7 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
                                      &resource_, attr));
     resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         context, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
-        allowed_devices_);
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
   }
 }

@@ -248,8 +245,7 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
     handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         ctx, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
-        allowed_devices_);
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
     ctx->set_output(0, handle);
   } else {
     ctx->set_output(0, resource_);

@@ -36,10 +36,6 @@ class VarHandleOp : public OpKernel {
   Tensor resource_;

   DtypeAndPartialTensorShape dtype_and_shape_;
-
-  // A set of devices containing the resource variable. Set when the output
-  // ResourceHandle represents a per-replica/partitioned resource variable.
-  std::vector<string> allowed_devices_;
 };

 class ReadVariableOp : public OpKernel {

@@ -30,7 +30,6 @@ from tensorflow.core.framework import tensor_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes

|
@ -1513,41 +1512,5 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
self.assertAllEqual(expected, result)
|
||||
|
||||
|
||||
class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(PerReplicaResourceHandleTest, self).setUp()
|
||||
cpus = config.list_physical_devices("CPU")
|
||||
# Set 2 virtual CPUs
|
||||
config.set_logical_device_configuration(cpus[0], [
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration(),
|
||||
])
|
||||
|
||||
@test_util.disable_tfrt("Multiple device support. b/154956430")
|
||||
def testAllowedDevices(self):
|
||||
device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
|
||||
value0 = 1
|
||||
value1 = 2
|
||||
with context.eager_mode():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
|
||||
with ops.device(device0):
|
||||
assign0 = resource_variable_ops.assign_variable_op(handle, value0)
|
||||
with ops.device(device1):
|
||||
assign1 = resource_variable_ops.assign_variable_op(handle, value1)
|
||||
with ops.control_dependencies([assign0, assign1]):
|
||||
with ops.device(device0):
|
||||
read0 = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
with ops.device(device1):
|
||||
read1 = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
|
||||
self.assertAllEqual(value0, read0)
|
||||
self.assertAllEqual(value1, read1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|