[TF:XLA] Support resource variable args into uninlined function calls
PiperOrigin-RevId: 240186584
This commit is contained in:
parent
f9cda530e4
commit
e9d0b39c6e
@ -341,6 +341,57 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
|||||||
var = f()
|
var = f()
|
||||||
self.assertEqual(1.0, var.numpy())
|
self.assertEqual(1.0, var.numpy())
|
||||||
|
|
||||||
|
def testResourceVariableNoInlineReadWrite(self):
|
||||||
|
with self.test_scope():
|
||||||
|
v = resource_variable_ops.ResourceVariable(1.0)
|
||||||
|
w = resource_variable_ops.ResourceVariable(0.0)
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def g(x):
|
||||||
|
w.assign(w.read_value() + x)
|
||||||
|
return v.read_value() + x * w.read_value()
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def f():
|
||||||
|
return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)
|
||||||
|
|
||||||
|
# 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
|
||||||
|
self.assertEqual(145.0, f().numpy())
|
||||||
|
self.assertEqual(15.0, w.read_value().numpy())
|
||||||
|
|
||||||
|
def testResourceVariableNoInlineReadOnly(self):
|
||||||
|
with self.test_scope():
|
||||||
|
v = resource_variable_ops.ResourceVariable(10.0)
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def g():
|
||||||
|
return v.read_value()
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def f():
|
||||||
|
return g() + g() + g() + g() + g()
|
||||||
|
|
||||||
|
self.assertEqual(50.0, f().numpy())
|
||||||
|
|
||||||
|
def testResourceVariableNoInlineWriteOnly(self):
|
||||||
|
with self.test_scope():
|
||||||
|
v = resource_variable_ops.ResourceVariable(0.0)
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def g(x):
|
||||||
|
v.assign(x)
|
||||||
|
|
||||||
|
@function.defun_with_attributes(attributes={'_noinline': True})
|
||||||
|
def f():
|
||||||
|
g(1.0)
|
||||||
|
g(2.0)
|
||||||
|
g(3.0)
|
||||||
|
g(4.0)
|
||||||
|
g(5.0)
|
||||||
|
|
||||||
|
f()
|
||||||
|
self.assertEqual(5.0, v.read_value().numpy())
|
||||||
|
|
||||||
def testUpdateVariable(self):
|
def testUpdateVariable(self):
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
v = resource_variable_ops.ResourceVariable(1.0)
|
||||||
|
@ -88,11 +88,18 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
|||||||
arg.kind = XlaCompiler::Argument::kParameter;
|
arg.kind = XlaCompiler::Argument::kParameter;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case XlaExpression::Kind::kResource:
|
case XlaExpression::Kind::kResource: {
|
||||||
// TODO(b/126601755): This is a fairly common use case in TF 2.0 that
|
XlaResource* resource = expressions[i]->resource();
|
||||||
// we can hit when inlining is disabled or fails.
|
|
||||||
return errors::Unimplemented(
|
arg.initialized = resource->initialized();
|
||||||
"Resource as function argument is not yet implemented.");
|
arg.kind = XlaCompiler::Argument::kResource;
|
||||||
|
arg.resource_kind = resource->kind();
|
||||||
|
arg.type = resource->type();
|
||||||
|
arg.shape = resource->shape();
|
||||||
|
arg.max_array_size = resource->max_array_size();
|
||||||
|
arg.name = resource->name();
|
||||||
|
break;
|
||||||
|
}
|
||||||
case XlaExpression::Kind::kTensorList:
|
case XlaExpression::Kind::kTensorList:
|
||||||
return errors::Unimplemented(
|
return errors::Unimplemented(
|
||||||
"TensorList as function argument is not yet implemented.");
|
"TensorList as function argument is not yet implemented.");
|
||||||
@ -266,7 +273,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
|||||||
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
|
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
handles.push_back(expressions[i]->handle());
|
if (arguments[i].kind == XlaCompiler::Argument::kResource) {
|
||||||
|
handles.push_back(expressions[i]->resource()->value());
|
||||||
|
} else {
|
||||||
|
handles.push_back(expressions[i]->handle());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (add_token_input_output) {
|
if (add_token_input_output) {
|
||||||
std::vector<string> token_input_nodes;
|
std::vector<string> token_input_nodes;
|
||||||
@ -296,6 +307,17 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
|||||||
++computation_output;
|
++computation_output;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int64 i = 0; i < result.resource_updates.size(); i++) {
|
||||||
|
if (result.resource_updates[i].modified) {
|
||||||
|
XlaResource* resource =
|
||||||
|
expressions[result.resource_updates[i].input_index]->resource();
|
||||||
|
xla::XlaOp updated_value =
|
||||||
|
xla::GetTupleElement(output_handle, i + n->num_outputs());
|
||||||
|
TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (add_token_input_output) {
|
if (add_token_input_output) {
|
||||||
TF_RETURN_IF_ERROR(compiler->SetNodeToken(
|
TF_RETURN_IF_ERROR(compiler->SetNodeToken(
|
||||||
n->name(), xla::GetTupleElement(output_handle, computation_output)));
|
n->name(), xla::GetTupleElement(output_handle, computation_output)));
|
||||||
|
Loading…
Reference in New Issue
Block a user