[TF:XLA] Support resource variable args into uninlined function calls
PiperOrigin-RevId: 240186584
This commit is contained in:
parent
f9cda530e4
commit
e9d0b39c6e
@ -341,6 +341,57 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
var = f()
|
||||
self.assertEqual(1.0, var.numpy())
|
||||
|
||||
def testResourceVariableNoInlineReadWrite(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
w = resource_variable_ops.ResourceVariable(0.0)
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def g(x):
|
||||
w.assign(w.read_value() + x)
|
||||
return v.read_value() + x * w.read_value()
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def f():
|
||||
return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)
|
||||
|
||||
# 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
|
||||
self.assertEqual(145.0, f().numpy())
|
||||
self.assertEqual(15.0, w.read_value().numpy())
|
||||
|
||||
def testResourceVariableNoInlineReadOnly(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(10.0)
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def g():
|
||||
return v.read_value()
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def f():
|
||||
return g() + g() + g() + g() + g()
|
||||
|
||||
self.assertEqual(50.0, f().numpy())
|
||||
|
||||
def testResourceVariableNoInlineWriteOnly(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(0.0)
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def g(x):
|
||||
v.assign(x)
|
||||
|
||||
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||
def f():
|
||||
g(1.0)
|
||||
g(2.0)
|
||||
g(3.0)
|
||||
g(4.0)
|
||||
g(5.0)
|
||||
|
||||
f()
|
||||
self.assertEqual(5.0, v.read_value().numpy())
|
||||
|
||||
def testUpdateVariable(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
|
@ -88,11 +88,18 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
}
|
||||
break;
|
||||
case XlaExpression::Kind::kResource:
|
||||
// TODO(b/126601755): This is a fairly common use case in TF 2.0 that
|
||||
// we can hit when inlining is disabled or fails.
|
||||
return errors::Unimplemented(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
case XlaExpression::Kind::kResource: {
|
||||
XlaResource* resource = expressions[i]->resource();
|
||||
|
||||
arg.initialized = resource->initialized();
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = resource->kind();
|
||||
arg.type = resource->type();
|
||||
arg.shape = resource->shape();
|
||||
arg.max_array_size = resource->max_array_size();
|
||||
arg.name = resource->name();
|
||||
break;
|
||||
}
|
||||
case XlaExpression::Kind::kTensorList:
|
||||
return errors::Unimplemented(
|
||||
"TensorList as function argument is not yet implemented.");
|
||||
@ -266,8 +273,12 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
|
||||
continue;
|
||||
}
|
||||
if (arguments[i].kind == XlaCompiler::Argument::kResource) {
|
||||
handles.push_back(expressions[i]->resource()->value());
|
||||
} else {
|
||||
handles.push_back(expressions[i]->handle());
|
||||
}
|
||||
}
|
||||
if (add_token_input_output) {
|
||||
std::vector<string> token_input_nodes;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
|
||||
@ -296,6 +307,17 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
++computation_output;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64 i = 0; i < result.resource_updates.size(); i++) {
|
||||
if (result.resource_updates[i].modified) {
|
||||
XlaResource* resource =
|
||||
expressions[result.resource_updates[i].input_index]->resource();
|
||||
xla::XlaOp updated_value =
|
||||
xla::GetTupleElement(output_handle, i + n->num_outputs());
|
||||
TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
|
||||
}
|
||||
}
|
||||
|
||||
if (add_token_input_output) {
|
||||
TF_RETURN_IF_ERROR(compiler->SetNodeToken(
|
||||
n->name(), xla::GetTupleElement(output_handle, computation_output)));
|
||||
|
Loading…
Reference in New Issue
Block a user