[Cleanup] Remove allowed_devices of ResourceHandle since it's no longer used.

PiperOrigin-RevId: 317710941
Change-Id: Ib1920c5ee25d405290f852b725d693ee5ea09766
This commit is contained in:
Yujing Zhang 2020-06-22 12:29:26 -07:00 committed by TensorFlower Gardener
parent aac1dd5788
commit 38d95ad2d8
15 changed files with 63 additions and 259 deletions

View File

@ -28,8 +28,8 @@ END
attr {
name: "allowed_devices"
description: <<END
The allowed devices containing the resource variable. Set when the output
ResourceHandle represents a per-replica/partitioned resource variable.
DEPRECATED. The allowed devices containing the resource variable. Set when the
output ResourceHandle represents a per-replica/partitioned resource variable.
END
}
summary: "Creates a handle to a Variable resource."

View File

@ -432,13 +432,12 @@ Status GetOrCreateKernelAndDevice(
// looking it up in ResourceMgr, which is slow). So we just get
// resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
// resource_dtypes_and_shapes is not empty, take the first element.
TensorHandle::ResourceHandleInfo resource_handle_info;
TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!resource_dtypes_and_shapes->empty()) {
std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
&resource_dtypes_and_shapes));
if (!resource_dtypes_and_shapes.empty()) {
const DtypeAndPartialTensorShape& dtype_and_shape =
resource_dtypes_and_shapes->at(0);
resource_dtypes_and_shapes.at(0);
input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
// Add _Arg index, dtype and shape to "cache_key".
@ -695,13 +694,8 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
const AttrValue* shape;
TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
TensorHandle::ResourceHandleInfo resource_handle_info = {
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
// "allowed_devices" is set only when the output represents a
// per-replica/partitioned resource variable.
TryGetNodeAttr(attr_slice, "allowed_devices",
&resource_handle_info.allowed_devices);
retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
retvals[0]->SetResourceHandleDtypeAndShape(
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
}
return Status::OK();
}
@ -985,18 +979,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// is a resource we must pin it to prevent different device selection.
// TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
std::vector<string> allowed_devices;
TF_RETURN_IF_ERROR(
tensor_handle->GetResourceAllowedDevices(&allowed_devices));
if (!allowed_devices.empty()) {
// TODO(b/145922293): Support allowed_devices specified in wildcard
// patterns.
if (std::find(allowed_devices.begin(), allowed_devices.end(),
op->DeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
&resource_device));
}
}
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
<< "device of operation " << op->Name() << " to "
<< resource_device->name() << " because input #" << i

View File

@ -145,13 +145,13 @@ Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
return Status::OK();
}
void TensorHandle::SetResourceHandleInfo(
ResourceHandleInfo&& resource_handle_info) {
resource_handle_info_ = std::move(resource_handle_info);
void TensorHandle::SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}
Status TensorHandle::GetResourceHandleInfoImpl(
std::function<void()> set_resource_info) {
Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result) {
if (dtype != DT_RESOURCE) {
return errors::InvalidArgument(
"TensorHandle::GetResourceDtypeAndShape should be called on tensor "
@ -160,7 +160,7 @@ Status TensorHandle::GetResourceHandleInfoImpl(
}
if (Type() != LOCAL) {
set_resource_info();
*result = handle_dtypes_and_shapes_;
return Status::OK();
}
@ -170,32 +170,10 @@ Status TensorHandle::GetResourceHandleInfoImpl(
auto& data = absl::get<LocalTensorHandleData>(data_);
TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));
set_resource_info();
*result = handle_dtypes_and_shapes_;
return Status::OK();
}
Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_.dtypes_and_shapes;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
auto get_resource_info = [result, this]() {
*result = resource_handle_info_.allowed_devices;
};
return GetResourceHandleInfoImpl(get_resource_info);
}
int TensorHandle::NumPackedHandles() const {
if (Type() != PACKED) {
return 0;
@ -270,9 +248,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
resource_remote_device_incarnation_(
GetRemoteDeviceIncarnation(resource_device_)),
ctx_(ctx),
resource_handle_info_(
{t.flat<class ResourceHandle>()(0).dtypes_and_shapes(),
t.flat<class ResourceHandle>()(0).allowed_devices()}),
handle_dtypes_and_shapes_(
t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_)
@ -327,10 +304,10 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
return errors::InvalidArgument("Handles should not be empty.");
}
ResourceHandleInfo resource_handle_info;
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
if (dtype == DT_RESOURCE) {
TF_RETURN_IF_ERROR(
handles.at(0)->GetResourceHandleInfo(&resource_handle_info));
handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
}
std::vector<string> devices;
for (auto* handle : handles) {
@ -348,7 +325,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
&composite_device));
*packed_handle =
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
(*packed_handle)
->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
return Status::OK();
}
@ -898,8 +876,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
auto& resource_handle = t.flat<class ResourceHandle>()(0);
resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
resource_handle.allowed_devices()};
handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
}
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.SetTensor(std::move(t));

View File

@ -226,19 +226,13 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
string DebugString() const;
struct ResourceHandleInfo {
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
std::vector<string> allowed_devices;
};
void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info);
void SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
// If this TensorHandle is 1) a local tensor, and 2) a resource handle,
// return data types, shapes and allowed devices of the underlying resource.
Status GetResourceHandleInfo(ResourceHandleInfo* result);
// return data types and shapes of the underlying resource.
Status GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
// Returns the number of packed handles. 0 if the handle type is not PACKED.
int NumPackedHandles() const;
@ -261,8 +255,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
// with a ready version of the tensor handle data.
bool IsReady() const;
Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to
@ -308,9 +300,9 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
Status is_poisoned_;
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types, shapes and allowed
// devices for the underlying resource.
ResourceHandleInfo resource_handle_info_;
// refers to a remote resource handle, we store data types and shapes for
// the underlying resource.
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// A handle data which refers to multiple TensorHandles of the same dtype and
// shape.

View File

@ -150,13 +150,13 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
Device* d0 = ListDevices().at(0);
TensorHandle* h0 =
TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context());
h0->SetResourceHandleInfo({{dtype_and_shape}, {}});
h0->SetResourceHandleDtypeAndShape({dtype_and_shape});
handles.push_back(h0);
Tensor t1(dtype, shape);
Device* d1 = ListDevices().at(1);
TensorHandle* h1 =
TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context());
h1->SetResourceHandleInfo({{dtype_and_shape}, {}});
h1->SetResourceHandleDtypeAndShape({dtype_and_shape});
handles.push_back(h1);
// Create 2 remote TensorHandles (not ready).
@ -185,13 +185,12 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
TensorShape packed_shape;
TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
EXPECT_EQ(packed_shape, shape);
TensorHandle::ResourceHandleInfo resource_handle_info;
TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info));
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1);
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT);
EXPECT_EQ(
resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}),
true);
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
TF_ASSERT_OK(
packed_handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
EXPECT_EQ(dtypes_and_shapes.size(), 1);
EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT);
EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true);
CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
absl::get<Device*>(packed_handle->device()));

View File

@ -167,24 +167,22 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
parent_->FindDeviceFromName(device_name.c_str(), &device));
*out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
in.dtype(), device, parent_);
TensorHandle::ResourceHandleInfo resource_handle_info;
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
dtypes_and_shapes)
&dtypes_and_shapes)
.ok()) {
for (const auto& dtype_and_shape_proto :
in.resource_dtypes_and_shapes()) {
dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())});
}
mutex_lock l(mirrored_resource_shape_mu_);
mirrored_resource_shape_map_.emplace(
RemoteTensorHandleInternal(in.op_id(), in.output_num()),
*dtypes_and_shapes);
dtypes_and_shapes);
}
(*out)->SetResourceHandleInfo(std::move(resource_handle_info));
(*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
}
return Status::OK();

View File

@ -38,9 +38,6 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const {
dtype_and_shape->set_dtype(dtype_and_shape_pair.dtype);
dtype_and_shape_pair.shape.AsProto(dtype_and_shape->mutable_shape());
}
for (const string& device : allowed_devices_) {
*proto->add_allowed_devices() = device;
}
}
void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
@ -56,9 +53,6 @@ void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape});
}
dtypes_and_shapes_ = std::move(dtypes_and_shapes);
for (const string& device : proto.allowed_devices()) {
allowed_devices_.push_back(device);
}
}
string ResourceHandle::SerializeAsString() const {

View File

@ -39,14 +39,8 @@ class ResourceHandle {
// Unique name for the device containing the resource.
const std::string& device() const { return device_; }
// Names of the devices containing the resource.
const std::vector<string>& allowed_devices() const {
return allowed_devices_;
}
void set_device(const std::string& device) { device_ = device; }
void set_allowed_devices(const std::vector<string>& devices) {
allowed_devices_ = devices;
}
// Container in which this resource is placed.
const std::string& container() const { return container_; }
@ -93,12 +87,7 @@ class ResourceHandle {
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
public:
// The default device containing the resource, where the ResourceHandle is
// initially created.
std::string device_;
// A set of devices containing the resource. If empty, the resource only
// exists on device_. Can be represented in wildcard patterns.
std::vector<string> allowed_devices_;
std::string container_;
std::string name_;
uint64 hash_code_ = 0;

View File

@ -41,7 +41,5 @@ message ResourceHandleProto {
// Data types and shapes for the underlying resource.
repeated DtypeAndShape dtypes_and_shapes = 6;
// A set of devices containing the resource. If empty, the resource only
// exists on `device`.
repeated string allowed_devices = 7;
reserved 7;
}

View File

@ -36,8 +36,7 @@ static std::atomic<int64> current_id_;
ResourceHandle MakeResourceHandle(
const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes,
const std::vector<string>& allowed_devices) {
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
ResourceHandle result;
result.set_device(device.name());
result.set_container(container);
@ -49,7 +48,6 @@ ResourceHandle MakeResourceHandle(
result.set_hash_code(type_index.hash_code());
result.set_maybe_type_name(type_index.name());
result.set_dtypes_and_shapes(dtypes_and_shapes);
result.set_allowed_devices(allowed_devices);
return result;
}
@ -67,39 +65,12 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
namespace internal {
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
const string& current_device_name = ctx->device()->attributes().name();
if (current_device_name == p.device()) {
return Status::OK();
}
DeviceNameUtils::ParsedName parsed_current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device_name,
&parsed_current_device_name)) {
if (ctx->device()->attributes().name() != p.device()) {
return errors::InvalidArgument(
"Cannot parse device name in OpKernelContext: ", current_device_name);
"Trying to access resource ", p.name(), " located in device ",
p.device(), " from device ", ctx->device()->attributes().name());
}
for (const string& device : p.allowed_devices()) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::InvalidArgument("Cannot parse allowed device name: ",
device);
}
if (DeviceNameUtils::IsCompleteSpecification(parsed,
parsed_current_device_name)) {
return Status::OK();
}
}
string error_message = strings::StrCat("Trying to access resource ", p.name(),
" located in device ", p.device(),
" from device ", current_device_name);
if (!p.allowed_devices().empty()) {
absl::StrAppend(&error_message, " (allowed devices: ");
for (const string& device : p.allowed_devices()) {
absl::StrAppend(&error_message, device, ", ");
}
absl::StrAppend(&error_message, ") ");
}
return errors::InvalidArgument(error_message);
return Status::OK();
}
} // end namespace internal

View File

@ -291,31 +291,27 @@ class ResourceMgr {
ResourceHandle MakeResourceHandle(
const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
const std::vector<string>& allowed_devices = {}) TF_MUST_USE_RESULT;
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
TF_MUST_USE_RESULT;
template <typename T>
ResourceHandle MakeResourceHandle(
OpKernelContext* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
const std::vector<string>& allowed_devices = {}) {
return MakeResourceHandle(container.empty()
? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(),
dtypes_and_shapes, allowed_devices);
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
return MakeResourceHandle(
container.empty() ? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
}
template <typename T>
ResourceHandle MakeResourceHandle(
OpKernelConstruction* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
const std::vector<string>& allowed_devices = {}) {
return MakeResourceHandle(container.empty()
? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(),
dtypes_and_shapes, allowed_devices);
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
return MakeResourceHandle(
container.empty() ? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
}
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,

View File

@ -352,51 +352,4 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
}
TEST(ResourceHandleTest, AllowedDevices) {
const std::vector<string> device_names = {
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:0/device:CPU:2",
"/job:worker/replica:1/task:3/device:CPU:5"};
std::vector<StubDevice> devices;
for (const string& name : device_names) {
devices.emplace_back(name);
}
std::vector<OpKernelContext::Params> params(device_names.size());
std::vector<std::unique_ptr<ResourceMgr>> resource_mgrs;
std::vector<std::unique_ptr<OpKernelContext>> ctxs;
for (int i = 0; i < device_names.size(); ++i) {
resource_mgrs.emplace_back(
absl::make_unique<ResourceMgr>(/* default_container= */ ""));
params[i].resource_manager = resource_mgrs[i].get();
params[i].device = &(devices[i]);
ctxs.emplace_back(
absl::make_unique<OpKernelContext>(&(params[i]), /* num_outputs= */ 0));
}
const string partially_specified_name =
"/job:worker/replica:0/task:0/device:CPU:*";
const string& fully_specified_name = device_names.at(2);
const std::vector<string> allowed_devices = {partially_specified_name,
fully_specified_name};
// Create a ResourceHandle on device 0.
ResourceHandle p = MakeResourceHandle<StubResource>(
ctxs[0].get(), "container", "name",
/* dtypes_and_shapes= */ {}, allowed_devices);
std::vector<StubResource*> resources;
for (const auto& ctx : ctxs) {
StubResource* r = new StubResource;
TF_EXPECT_OK(CreateResource(ctx.get(), p, r));
resources.push_back(r);
}
for (int i = 0; i < ctxs.size(); ++i) {
core::RefCountPtr<StubResource> lookup_r;
TF_EXPECT_OK(LookupResource<StubResource>(ctxs[i].get(), p, &lookup_r));
EXPECT_EQ(lookup_r.get(), resources[i]);
TF_EXPECT_OK(DeleteResource(ctxs[i].get(), p));
}
}
} // end namespace tensorflow

View File

@ -222,8 +222,6 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
PartialTensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
OP_REQUIRES_OK(context,
context->GetAttr("allowed_devices", &allowed_devices_));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
@ -234,8 +232,7 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
&resource_, attr));
resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
context, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
}
}
@ -248,8 +245,7 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
ctx, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
ctx->set_output(0, handle);
} else {
ctx->set_output(0, resource_);

View File

@ -36,10 +36,6 @@ class VarHandleOp : public OpKernel {
Tensor resource_;
DtypeAndPartialTensorShape dtype_and_shape_;
// A set of devices containing the resource variable. Set when the output
// ResourceHandle represents a per-replica/partitioned resource variable.
std::vector<string> allowed_devices_;
};
class ReadVariableOp : public OpKernel {

View File

@ -30,7 +30,6 @@ from tensorflow.core.framework import tensor_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
@ -1513,41 +1512,5 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertAllEqual(expected, result)
class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
def setUp(self):
super(PerReplicaResourceHandleTest, self).setUp()
cpus = config.list_physical_devices("CPU")
# Set 2 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
])
@test_util.disable_tfrt("Multiple device support. b/154956430")
def testAllowedDevices(self):
device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
value0 = 1
value1 = 2
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
with ops.device(device0):
assign0 = resource_variable_ops.assign_variable_op(handle, value0)
with ops.device(device1):
assign1 = resource_variable_ops.assign_variable_op(handle, value1)
with ops.control_dependencies([assign0, assign1]):
with ops.device(device0):
read0 = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
with ops.device(device1):
read1 = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
self.assertAllEqual(value0, read0)
self.assertAllEqual(value1, read1)
if __name__ == "__main__":
test.main()