From e9d0b39c6eb8da5aa39e78adbc193c866588909a Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 25 Mar 2019 11:46:27 -0700 Subject: [PATCH] [TF:XLA] Support resource variable args into uninlined function calls PiperOrigin-RevId: 240186584 --- tensorflow/compiler/tests/eager_test.py | 51 ++++++++++++++++++++ tensorflow/compiler/tf2xla/graph_compiler.cc | 34 ++++++++++--- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index d2cfcb674c3..a13fbc9815f 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -341,6 +341,57 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f() self.assertEqual(1.0, var.numpy()) + def testResourceVariableNoInlineReadWrite(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + w = resource_variable_ops.ResourceVariable(0.0) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def g(x): + w.assign(w.read_value() + x) + return v.read_value() + x * w.read_value() + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0) + + # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15 + self.assertEqual(145.0, f().numpy()) + self.assertEqual(15.0, w.read_value().numpy()) + + def testResourceVariableNoInlineReadOnly(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(10.0) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def g(): + return v.read_value() + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + return g() + g() + g() + g() + g() + + self.assertEqual(50.0, f().numpy()) + + def testResourceVariableNoInlineWriteOnly(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(0.0) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def g(x): + v.assign(x) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0) + + f() + self.assertEqual(5.0, v.read_value().numpy()) + def testUpdateVariable(self): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e80b6f50ac3..c9d1ba287dc 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -88,11 +88,18 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.kind = XlaCompiler::Argument::kParameter; } break; - case XlaExpression::Kind::kResource: - // TODO(b/126601755): This is a fairly common use case in TF 2.0 that - // we can hit when inlining is disabled or fails. - return errors::Unimplemented( - "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kResource: { + XlaResource* resource = expressions[i]->resource(); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + arg.type = resource->type(); + arg.shape = resource->shape(); + arg.max_array_size = resource->max_array_size(); + arg.name = resource->name(); + break; + } case XlaExpression::Kind::kTensorList: return errors::Unimplemented( "TensorList as function argument is not yet implemented."); @@ -266,7 +273,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; } - handles.push_back(expressions[i]->handle()); + if (arguments[i].kind == XlaCompiler::Argument::kResource) { + handles.push_back(expressions[i]->resource()->value()); + } else { + handles.push_back(expressions[i]->handle()); + } } if (add_token_input_output) { std::vector token_input_nodes; @@ -296,6 +307,17 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, ++computation_output; } } + + for (int64 i = 0; i < result.resource_updates.size(); i++) { + if (result.resource_updates[i].modified) { + XlaResource* resource = + expressions[result.resource_updates[i].input_index]->resource(); + xla::XlaOp updated_value = + xla::GetTupleElement(output_handle, i + n->num_outputs()); + TF_RETURN_IF_ERROR(resource->SetValue(updated_value)); + } + } + if (add_token_input_output) { TF_RETURN_IF_ERROR(compiler->SetNodeToken( n->name(), xla::GetTupleElement(output_handle, computation_output)));