[TF:XLA] Support resource variable args into uninlined function calls

PiperOrigin-RevId: 240186584
This commit is contained in:
Sanjoy Das 2019-03-25 11:46:27 -07:00 committed by TensorFlower Gardener
parent f9cda530e4
commit e9d0b39c6e
2 changed files with 79 additions and 6 deletions

View File

@@ -341,6 +341,57 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f()
self.assertEqual(1.0, var.numpy())
def testResourceVariableNoInlineReadWrite(self):
  """Reads and writes resource variables inside _noinline defuns."""
  with self.test_scope():
    base = resource_variable_ops.ResourceVariable(1.0)
    accum = resource_variable_ops.ResourceVariable(0.0)

    @function.defun_with_attributes(attributes={'_noinline': True})
    def accumulate(x):
      # Mutates `accum`, then reads back both variables.
      accum.assign(accum.read_value() + x)
      return base.read_value() + x * accum.read_value()

    @function.defun_with_attributes(attributes={'_noinline': True})
    def run_all():
      return (accumulate(1.0) + accumulate(2.0) + accumulate(3.0) +
              accumulate(4.0) + accumulate(5.0))

    # `accum` takes the running sums 1, 3, 6, 10, 15, so the five calls
    # return (1+1*1) + (1+2*3) + (1+3*6) + (1+4*10) + (1+5*15) = 145.
    self.assertEqual(145.0, run_all().numpy())
    self.assertEqual(15.0, accum.read_value().numpy())
def testResourceVariableNoInlineReadOnly(self):
  """Reads (never writes) a resource variable inside _noinline defuns."""
  with self.test_scope():
    var = resource_variable_ops.ResourceVariable(10.0)

    @function.defun_with_attributes(attributes={'_noinline': True})
    def read_var():
      return var.read_value()

    @function.defun_with_attributes(attributes={'_noinline': True})
    def total():
      # Five reads of the same unchanging variable: 5 * 10.0.
      return read_var() + read_var() + read_var() + read_var() + read_var()

    self.assertEqual(50.0, total().numpy())
def testResourceVariableNoInlineWriteOnly(self):
  """Writes (never reads) a resource variable inside _noinline defuns."""
  with self.test_scope():
    var = resource_variable_ops.ResourceVariable(0.0)

    @function.defun_with_attributes(attributes={'_noinline': True})
    def store(x):
      var.assign(x)

    @function.defun_with_attributes(attributes={'_noinline': True})
    def store_all():
      # Sequential writes; the final assignment determines the value.
      store(1.0)
      store(2.0)
      store(3.0)
      store(4.0)
      store(5.0)

    store_all()
    self.assertEqual(5.0, var.read_value().numpy())
def testUpdateVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)

View File

@@ -88,11 +88,18 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.kind = XlaCompiler::Argument::kParameter;
}
break;
case XlaExpression::Kind::kResource:
// TODO(b/126601755): This is a fairly common use case in TF 2.0 that
// we can hit when inlining is disabled or fails.
return errors::Unimplemented(
"Resource as function argument is not yet implemented.");
case XlaExpression::Kind::kResource: {
XlaResource* resource = expressions[i]->resource();
arg.initialized = resource->initialized();
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind();
arg.type = resource->type();
arg.shape = resource->shape();
arg.max_array_size = resource->max_array_size();
arg.name = resource->name();
break;
}
case XlaExpression::Kind::kTensorList:
return errors::Unimplemented(
"TensorList as function argument is not yet implemented.");
@@ -266,7 +273,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
continue;
}
handles.push_back(expressions[i]->handle());
if (arguments[i].kind == XlaCompiler::Argument::kResource) {
handles.push_back(expressions[i]->resource()->value());
} else {
handles.push_back(expressions[i]->handle());
}
}
if (add_token_input_output) {
std::vector<string> token_input_nodes;
@@ -296,6 +307,17 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
++computation_output;
}
}
for (int64 i = 0; i < result.resource_updates.size(); i++) {
if (result.resource_updates[i].modified) {
XlaResource* resource =
expressions[result.resource_updates[i].input_index]->resource();
xla::XlaOp updated_value =
xla::GetTupleElement(output_handle, i + n->num_outputs());
TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
}
}
if (add_token_input_output) {
TF_RETURN_IF_ERROR(compiler->SetNodeToken(
n->name(), xla::GetTupleElement(output_handle, computation_output)));