PiperOrigin-RevId: 299965831
Change-Id: I22b867dd0c7c95552d52b49595f1f5b5279639ec
This commit is contained in:
Yujing Zhang 2020-03-09 16:37:17 -07:00 committed by TensorFlower Gardener
parent 6f968a3a59
commit bae9776f57
13 changed files with 167 additions and 48 deletions

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/dump_graph.h"
@ -128,19 +130,31 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
return Status::OK();
}
void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
for (auto& node : *graph_def->mutable_node()) {
Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
auto update_var_handle_op_node = [](NodeDef& node) -> Status {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
const auto& it = node.attr().find("allowed_devices");
if (it != node.attr().end()) {
if (!it->second.list().s().empty()) {
// TODO(b/149512838): Support non-empty allowed devices.
return errors::InvalidArgument(
"VarHandleOp with non-empty allowed devices is not supported.");
}
node.mutable_attr()->erase("allowed_devices");
}
}
return Status::OK();
};
for (auto& node : *graph_def->mutable_node()) {
TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
}
for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
for (auto& node : *fn.mutable_node_def()) {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
}
TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
}
}
return Status::OK();
}
} // namespace
@ -149,7 +163,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
xla::Client* client,
xla::XlaComputation* computation) {
std::unique_ptr<Graph> graph;
ConvertVarHandlesToAotVarHandles(&graph_def);
TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def));
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
TF_RETURN_IF_ERROR(
ConvertGraphToXla(std::move(graph), config, client, computation));

View File

@ -23,6 +23,13 @@ END
name: "shape"
description: <<END
The (possibly partially specified) shape of this variable.
END
}
attr {
name: "allowed_devices"
description: <<END
The allowed devices containing the resource variable. Set when the output
ResourceHandle represents a per-replica/partitioned resource variable.
END
}
summary: "Creates a handle to a Variable resource."

View File

@ -437,12 +437,13 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
// looking it up in ResourceMgr, which is slow). So we just get
// resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
// resource_dtypes_and_shapes is not empty, take the first element.
std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
&resource_dtypes_and_shapes));
if (!resource_dtypes_and_shapes.empty()) {
TensorHandle::ResourceHandleInfo resource_handle_info;
TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!resource_dtypes_and_shapes->empty()) {
const DtypeAndPartialTensorShape& dtype_and_shape =
resource_dtypes_and_shapes.at(0);
resource_dtypes_and_shapes->at(0);
input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
// Add _Arg index, dtype and shape to "cache_key".
@ -647,8 +648,13 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
const AttrValue* shape;
TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
retvals[0]->SetResourceHandleDtypeAndShape(
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
TensorHandle::ResourceHandleInfo resource_handle_info = {
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
// "allowed_devices" is set only when the output represents a
// per-replica/partitioned resource variable.
TryGetNodeAttr(attr_slice, "allowed_devices",
&resource_handle_info.allowed_devices);
retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
}
return Status::OK();
}
@ -869,6 +875,19 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// is a resource we must pin it to prevent different device selection.
// TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
std::vector<string> allowed_devices;
TF_RETURN_IF_ERROR(
tensor_handle->GetResourceAllowedDevices(&allowed_devices));
if (!allowed_devices.empty()) {
// TODO(b/145922293): Support allowed_devices specified in wildcard
// patterns.
std::vector<string> device_names;
if (std::find(allowed_devices.begin(), allowed_devices.end(),
op->GetDeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
op->GetDeviceName().c_str(), &resource_device));
}
}
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
<< "device of operation " << op->Name() << " to "
<< resource_device->name() << " because input #" << i

View File

@ -55,13 +55,13 @@ const int32 kInvalidOutputNum = -1;
#endif
} // namespace
void TensorHandle::SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
void TensorHandle::SetResourceHandleInfo(
ResourceHandleInfo&& resource_handle_info) {
resource_handle_info_ = std::move(resource_handle_info);
}
Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result) {
Status TensorHandle::GetResourceHandleInfoImpl(
std::function<void()> set_resource_info) {
if (dtype != DT_RESOURCE) {
return errors::InvalidArgument(
"TensorHandle::GetResourceDtypeAndShape should be called on tensor "
@ -70,22 +70,42 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
}
if (IsRemote()) {
*result = handle_dtypes_and_shapes_;
set_resource_info();
return Status::OK();
}
// Wait for this TensorHandle to be ready.
profiler::TraceMe activity(
"TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
profiler::TraceMeLevel::kInfo);
profiler::TraceMe activity("TensorHandle::GetResourceHandleInfo WaitReady",
profiler::TraceMeLevel::kInfo);
auto& data = absl::get<LocalTensorHandleData>(data_);
TF_RETURN_IF_ERROR(
data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));
*result = handle_dtypes_and_shapes_;
set_resource_info();
return Status::OK();
}
// Copies the handle's full resource metadata (dtypes/shapes plus allowed
// devices) into `*result`, waiting for the handle to be ready if necessary.
Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_; });
}
// Copies only the dtype/shape portion of the resource metadata into
// `*result`, waiting for the handle to be ready if necessary.
Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_.dtypes_and_shapes; });
}
// Copies only the allowed-devices list of the resource metadata into
// `*result`, waiting for the handle to be ready if necessary.
Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_.allowed_devices; });
}
Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
TensorHandle** h) {
// TODO(b/136608821): Move away from nullptr
@ -145,8 +165,9 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
ctx_(ctx),
implicit_mirroring_(true),
handle_dtypes_and_shapes_(
t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
resource_handle_info_(
{t.flat<class ResourceHandle>()(0).dtypes_and_shapes(),
t.flat<class ResourceHandle>()(0).allowed_devices()}),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_)
@ -669,7 +690,8 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
auto& resource_handle = t.flat<class ResourceHandle>()(0);
handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
resource_handle.allowed_devices()};
}
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.SetTensor(std::move(t));

View File

@ -210,13 +210,19 @@ class TensorHandle : public core::RefCounted {
string DebugString() const;
void SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);
struct ResourceHandleInfo {
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
std::vector<string> allowed_devices;
};
void SetResourceHandleInfo(ResourceHandleInfo&& resource_handle_info);
// If this TensorHandle is 1) a local tensor, and 2) a resource handle,
// return data types and shapes of the underlying resource.
// return data types, shapes and allowed devices of the underlying resource.
Status GetResourceHandleInfo(ResourceHandleInfo* result);
Status GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
private:
// The TensorHandleData can either represent a local or remote tensor handle.
@ -225,6 +231,8 @@ class TensorHandle : public core::RefCounted {
// with a ready version of the tensor handle data.
bool IsReady() const;
Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to
@ -268,9 +276,9 @@ class TensorHandle : public core::RefCounted {
bool implicit_mirroring_;
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types and shapes for
// the underlying resource.
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// refers to a remote resource handle, we store data types, shapes and allowed
// devices for the underlying resource.
ResourceHandleInfo resource_handle_info_;
// Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, data_ is immutable.

View File

@ -164,22 +164,24 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
parent_->FindDeviceFromName(device_name.c_str(), &device));
TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
TensorHandle::ResourceHandleInfo resource_handle_info;
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
&dtypes_and_shapes)
dtypes_and_shapes)
.ok()) {
for (const auto& dtype_and_shape_proto :
in.resource_dtypes_and_shapes()) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())});
}
mutex_lock l(mirrored_resource_shape_mu_);
mirrored_resource_shape_map_.emplace(
RemoteTensorHandleInternal(in.op_id(), in.output_num()),
dtypes_and_shapes);
*dtypes_and_shapes);
}
(*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
(*out)->SetResourceHandleInfo(std::move(resource_handle_info));
}
return Status::OK();

View File

@ -314,11 +314,13 @@ ResourceHandle MakeResourceHandle(
template <typename T>
ResourceHandle MakeResourceHandle(
OpKernelConstruction* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
return MakeResourceHandle(
container.empty() ? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
const std::vector<string>& allowed_devices = {}) {
return MakeResourceHandle(container.empty()
? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(),
dtypes_and_shapes, allowed_devices);
}
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,

View File

@ -229,6 +229,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
PartialTensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
OP_REQUIRES_OK(context,
context->GetAttr("allowed_devices", &allowed_devices_));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
@ -239,7 +241,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
&resource_, attr));
resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
context, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
}
}
@ -252,7 +255,8 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
ctx, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
ctx->set_output(0, handle);
} else {
ctx->set_output(0, resource_);

View File

@ -36,6 +36,10 @@ class VarHandleOp : public OpKernel {
Tensor resource_;
DtypeAndPartialTensorShape dtype_and_shape_;
// A set of devices containing the resource variable. Set when the output
// ResourceHandle represents a per-replica/partitioned resource variable.
std::vector<string> allowed_devices_;
};
class ReadVariableOp : public OpKernel {

View File

@ -80,6 +80,7 @@ REGISTER_OP("VarHandleOp")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Attr("allowed_devices: list(string) = []")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {

View File

@ -30,6 +30,7 @@ from tensorflow.core.framework import tensor_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
@ -1488,5 +1489,40 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertAllEqual(expected, result)
class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
  """Tests resource handles created with a non-empty `allowed_devices`.

  Such a handle represents a per-replica/partitioned resource variable:
  each device listed in `allowed_devices` holds its own copy of the
  variable state.
  """

  def setUp(self):
    super(PerReplicaResourceHandleTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Split the first physical CPU into two logical devices so ops can be
    # placed on two distinct devices without real multi-CPU hardware.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testAllowedDevices(self):
    # Assign a different value to the same handle on each allowed device,
    # then read back per device: the copies must stay independent.
    device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
    device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
    value0 = 1
    value1 = 2
    with context.eager_mode():
      handle = resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
      with ops.device(device0):
        assign0 = resource_variable_ops.assign_variable_op(handle, value0)
      with ops.device(device1):
        assign1 = resource_variable_ops.assign_variable_op(handle, value1)
      # Ensure both assignments complete before either read.
      with ops.control_dependencies([assign0, assign1]):
        with ops.device(device0):
          read0 = resource_variable_ops.read_variable_op(
              handle, dtype=dtypes.int32)
        with ops.device(device1):
          read1 = resource_variable_ops.read_variable_op(
              handle, dtype=dtypes.int32)
      self.assertAllEqual(value0, read0)
      self.assertAllEqual(value1, read1)
if __name__ == "__main__":
test.main()

View File

@ -4962,7 +4962,7 @@ tf_module {
}
member_method {
name: "VarHandleOp"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
}
member_method {
name: "VarIsInitializedOp"

View File

@ -4962,7 +4962,7 @@ tf_module {
}
member_method {
name: "VarHandleOp"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
}
member_method {
name: "VarIsInitializedOp"