Support returning resource handles from functions in XLA

There are a couple of reasons to do this:
 - Resource handles are regular tensors that are part of the public API
 and can potentially be returned from a function.
 - When tfe.defun is executed under GradientTape, it generates a
 function returning resource handles in certain cases.

This CL adds support for returning resource handles from an XLA
compiled function. These resource handles must have been passed as
arguments to the function. In other words, we don't yet support
returning resources created inside the function. tfe.defun never
makes functions that create resources.

PiperOrigin-RevId: 210442856
This commit is contained in:
Igor Ganichev 2018-08-27 15:29:56 -07:00 committed by TensorFlower Gardener
parent df6c8721f8
commit fc492c08d6
10 changed files with 263 additions and 37 deletions

View File

@ -209,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;

View File

@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
xla_tensor->SetDefinedOn(stream, definition_event);
const DataType& type = kernel->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
CHECK_EQ(output_tensor->TotalBytes(), 0);
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
++output_num;
}
++output_num;
}
if (VLOG_IS_ON(3)) {

View File

@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f(v)
self.assertEqual(2.0, var.numpy())
# Checks that a defun can return the resource handle of a variable passed to it
# and that the returned handle still reads back the variable's value.
def testReturnResourceHandle(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])
# The function's only output is the handle of its argument.
def f(v):
return v.handle
f = function.defun(f)
handle = f(v)
# Reading through the returned handle must match the variable's own value.
self.assertAllEqual(v.numpy(),
resource_variable_ops.read_variable_op(
handle, dtypes.float32).numpy())
# Checks a defun that returns two resource handles interleaved with two
# computed values.
def testReturnMultipleResourceHandles(self):
with self.test_scope():
v1 = resource_variable_ops.ResourceVariable(1.25)
v2 = resource_variable_ops.ResourceVariable(2.0)
# `v` is the function's argument (v1 at the call below); `v2` is captured from
# the enclosing scope rather than passed in.
def f(v):
return v.handle, 3.0 * v, v2.handle, v + v2
f = function.defun(f)
v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
# Each returned handle must read back the value of its variable.
self.assertAllEqual(v1.numpy(),
resource_variable_ops.read_variable_op(
v1_handle, dtypes.float32).numpy())
# 3.0 * 1.25
self.assertEqual(3.75, v1_times_3.numpy())
self.assertAllEqual(v2.numpy(),
resource_variable_ops.read_variable_op(
v2_handle, dtypes.float32).numpy())
# 1.25 + 2.0
self.assertEqual(3.25, variable_sum.numpy())
def testAllArgumentKinds(self):
"""Test a complex function that takes different argument kinds.
@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase):
y = two_x_plus_1(x)
self.assertAllEqual([5, 7, 9], y.numpy())
# Checks that a defun calling another defun works when both read the same
# variable.
def testNestedDefunWithVariable(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
# Inner function: multiplies its input by v0.
@function.defun
def g(x):
x = v0 * x
return x
# Outer function: multiplies by v0, then calls g, i.e. computes v0 * v0 * x.
@function.defun
def f(x):
x = g(v0 * x)
return x
x = constant_op.constant(3.0)
y = f(x)
# 5.0 * 5.0 * 3.0
self.assertEqual(75, y.numpy())
# Checks that gradients with respect to a variable flow through nested defuns
# when the outer call is recorded on a GradientTape.
def testNestedDefunInGradientTape(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
# Inner function: multiplies its input by v0.
@function.defun
def g(x):
x = v0 * x
return x
# Outer function: computes v0 * v0 * x by way of g.
@function.defun
def f(x):
x = g(v0 * x)
return x
x = constant_op.constant(3.0)
with backprop.GradientTape() as tape:
y = f(x)
dy = tape.gradient(y, v0)
# y = v0^2 * x = 5.0 * 5.0 * 3.0
self.assertEqual(75, y.numpy())
# dy/dv0 = 2 * v0 * x = 2 * 5.0 * 3.0
self.assertEqual(30, dy.numpy())
# Same as the nested GradientTape test, except that the inner and outer
# functions read different variables and a gradient is taken for each one.
def testNestedDefunInGradientTapeDifferentVars(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
v1 = resource_variable_ops.ResourceVariable(3.0)
# Inner function: multiplies its input by v1.
@function.defun
def g(x):
x = v1 * x
return x
# Outer function: multiplies by v0, then calls g, i.e. computes v1 * v0 * x.
@function.defun
def f(x):
x = g(v0 * x)
return x
x = constant_op.constant(3.0)
# The tape is persistent because tape.gradient is called twice below.
with backprop.GradientTape(persistent=True) as tape:
y = f(x)
dy_v0 = tape.gradient(y, v0)
dy_v1 = tape.gradient(y, v1)
# y = v1 * v0 * x = 3.0 * 5.0 * 3.0
self.assertEqual(45, y.numpy())
# dy/dv0 = v1 * x = 3.0 * 3.0
self.assertEqual(9, dy_v0.numpy())
# dy/dv1 = v0 * x = 5.0 * 3.0
self.assertEqual(15, dy_v1.numpy())
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.

View File

@ -146,6 +146,7 @@ Status GraphCompiler::Compile() {
}
OpKernelContext op_context(&params, n->num_outputs());
VLOG(3) << "Translating " << params.op_kernel->name();
if (IsFunctional(n)) {
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
} else {

View File

@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel {
} else {
xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
DataType input_type = ctx->input_type(0);
XlaContext& tc = XlaContext::Get(ctx);
if (input_type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
return;
}
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel {
return;
}
XlaContext& tc = XlaContext::Get(ctx);
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
RetvalOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -361,6 +361,9 @@ Status BuildComputation(
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else if (retval.resource() != nullptr) {
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
@ -495,7 +498,8 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_constant_value(arg.constant_value);
break;
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Unreachable case in BuildArguments()");
return errors::Internal(
"Unreachable case in BuildArguments() while filling constant args");
}
}
@ -615,7 +619,8 @@ Status XlaCompiler::BuildArguments(
break;
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Unreachable case in BuildArguments()");
return errors::Internal(
"Unreachable case in BuildArguments() while filling handles");
}
}

View File

@ -183,6 +183,8 @@ class XlaCompiler {
struct OutputDescription {
// Type and shape of the output. The shape is the unflattened shape.
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
// variable's value.
DataType type;
TensorShape shape;
@ -190,6 +192,10 @@ class XlaCompiler {
// 'Tensor' is in host memory.
bool is_constant = false;
Tensor constant_value;
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
// the index of the input that contains the resource. It defaults to -1 so
// that outputs which are not resources carry a well-defined, recognizably
// invalid index instead of an indeterminate value.
int input_index = -1;
};
// Describes a variable write side effect of the computation.

View File

@ -861,6 +861,33 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
<< status.error_message();
}
// Executes the compiled computation in `result` on `client` with two fixed
// rank-1 int32 inputs, {7, 42} and {-3, 101}, and expects the output to be the
// tuple ({5, 144}, {4, 143}).
void RunAndCheckVariablesComputation(
    xla::Client* client, const XlaCompiler::CompilationResult& result) {
  // Build each input literal and stage it on the server.
  std::unique_ptr<xla::Literal> first_input =
      xla::LiteralUtil::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> first_input_data =
      client->TransferToServer(*first_input).ConsumeValueOrDie();

  std::unique_ptr<xla::Literal> second_input =
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> second_input_data =
      client->TransferToServer(*second_input).ConsumeValueOrDie();

  // Run the computation and copy its result back as a literal.
  std::unique_ptr<xla::GlobalData> output_data =
      client
          ->Execute(*result.computation,
                    {first_input_data.get(), second_input_data.get()})
          .ConsumeValueOrDie();
  std::unique_ptr<xla::Literal> actual_literal =
      client->Transfer(*output_data).ConsumeValueOrDie();

  // The expected result is a two-element tuple of rank-1 literals.
  std::unique_ptr<xla::Literal> first_expected =
      xla::LiteralUtil::CreateR1<int32>({5, 144});
  std::unique_ptr<xla::Literal> second_expected =
      xla::LiteralUtil::CreateR1<int32>({4, 143});
  std::unique_ptr<xla::Literal> expected_literal =
      xla::LiteralUtil::MakeTuple({first_expected.get(), second_expected.get()});

  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest, Variables) {
Scope scope = Scope::NewRootScope().ExitOnError();
@ -892,36 +919,90 @@ TEST_F(XlaCompilerTest, Variables) {
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
RunAndCheckVariablesComputation(client_, result);
}
// Tests a graph whose only output is the resource handle passed in as its
// argument.
TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kResource;
args[0].resource_kind = XlaResource::kVariable;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR1<int32>({5, 144});
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
xla::LiteralUtil::MakeTuple({});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests a graph that returns the resource handle it received as an argument
// (retval 0) together with a value computed from the updated variable
// (retval 1).
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
// Adds an identity op around the resource to make sure identity ops propagate
// resources correctly.
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
auto write = ops::AssignAddVariableOp(scope, identity, a);
// The read carries a control dependency on the write, so it is ordered after
// the assign-add.
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
// Retval 0 is the resource handle itself; retval 1 is the computed value.
auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
// For the resource argument, `type` and `shape` are set to DT_INT32 and {2}.
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
// Runs the computation and checks its output with the shared checker.
RunAndCheckVariablesComputation(client_, result);
}
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);

View File

@ -107,6 +107,19 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
return Status::OK();
}
// Records `resource` as return value number `retval_index` of the computation.
// The retval is stored with type DT_RESOURCE and the resource's shape; the
// retval list grows as needed to hold the index. Always returns OK.
Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
  VLOG(1) << "Adding retval index " << retval_index << " with resource "
          << resource->name() << ":" << resource->shape().DebugString()
          << " to XLA computation";

  // Make sure a slot exists for this index before writing to it.
  const bool needs_growth =
      static_cast<size_t>(retval_index) >= retvals_.size();
  if (needs_growth) {
    retvals_.resize(retval_index + 1);
  }

  XlaExpression expression;
  expression.set_resource(resource);
  retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), expression};
  return Status::OK();
}
// Returns the XlaBuilder pointer stored in this context.
xla::XlaBuilder* XlaContext::builder() {
  return builder_;
}
Status XlaContext::CreateResource(

View File

@ -86,6 +86,9 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal);
// As for Retval, but for return values that are resource handles. Records
// `resource` as the return value at position `retval_index`.
Status AddResourceRetval(int retval_index, XlaResource* resource);
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.