Support returning resource handles from function in XLA
There are a couple of reasons to do this: - resource handle are regular tensors part of a public API that can potentially be returned from a function. - When tfe.defun is executed under GradientTape, it generates a function returning resource handles in certain cases. This CL adds support for returning resource handles from an XLA compiled function. These resource handles must have been passed as arguments to the function. In other words, we don't yet support returning resources created inside the function. tfe.defun never makes functions that create resources. PiperOrigin-RevId: 210442856
This commit is contained in:
parent
df6c8721f8
commit
fc492c08d6
tensorflow/compiler
@ -209,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
// device memory.
|
||||
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory
|
||||
// in device memory except for resources.
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
for (int i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
||||
output_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
|
||||
// Create the kernel.
|
||||
NameAttrList function;
|
||||
|
@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
}
|
||||
} else {
|
||||
const TensorShape& shape = kernel->outputs[i].shape;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
|
||||
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
if (allocate_xla_tensors_) {
|
||||
Tensor* output_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
if (xla_tensor) {
|
||||
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
|
||||
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
|
||||
if (use_multiple_streams_) {
|
||||
xla_tensor->SetDefinedOn(stream, definition_event);
|
||||
const DataType& type = kernel->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_RESOURCE) {
|
||||
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
if (allocate_xla_tensors_) {
|
||||
Tensor* output_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
if (xla_tensor) {
|
||||
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
|
||||
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
|
||||
if (use_multiple_streams_) {
|
||||
xla_tensor->SetDefinedOn(stream, definition_event);
|
||||
}
|
||||
} else {
|
||||
// xla_tensor wasn't valid, which must mean this is a zero-element
|
||||
// tensor.
|
||||
CHECK_EQ(output_tensor->TotalBytes(), 0);
|
||||
}
|
||||
} else {
|
||||
// xla_tensor wasn't valid, which must mean this is a zero-element
|
||||
// tensor.
|
||||
CHECK_EQ(output_tensor->TotalBytes(), 0);
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
} else {
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
++output_num;
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
|
@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
var = f(v)
|
||||
self.assertEqual(2.0, var.numpy())
|
||||
|
||||
def testReturnResourceHandle(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
def f(v):
|
||||
return v.handle
|
||||
|
||||
f = function.defun(f)
|
||||
handle = f(v)
|
||||
self.assertAllEqual(v.numpy(),
|
||||
resource_variable_ops.read_variable_op(
|
||||
handle, dtypes.float32).numpy())
|
||||
|
||||
def testReturnMultipleResourceHandles(self):
|
||||
with self.test_scope():
|
||||
v1 = resource_variable_ops.ResourceVariable(1.25)
|
||||
v2 = resource_variable_ops.ResourceVariable(2.0)
|
||||
|
||||
def f(v):
|
||||
return v.handle, 3.0 * v, v2.handle, v + v2
|
||||
|
||||
f = function.defun(f)
|
||||
v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
|
||||
self.assertAllEqual(v1.numpy(),
|
||||
resource_variable_ops.read_variable_op(
|
||||
v1_handle, dtypes.float32).numpy())
|
||||
self.assertEqual(3.75, v1_times_3.numpy())
|
||||
self.assertAllEqual(v2.numpy(),
|
||||
resource_variable_ops.read_variable_op(
|
||||
v2_handle, dtypes.float32).numpy())
|
||||
self.assertEqual(3.25, variable_sum.numpy())
|
||||
|
||||
def testAllArgumentKinds(self):
|
||||
"""Test a complex function that takes different argument kinds.
|
||||
|
||||
@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
y = two_x_plus_1(x)
|
||||
self.assertAllEqual([5, 7, 9], y.numpy())
|
||||
|
||||
def testNestedDefunWithVariable(self):
|
||||
with self.test_scope():
|
||||
v0 = resource_variable_ops.ResourceVariable(5.0)
|
||||
|
||||
@function.defun
|
||||
def g(x):
|
||||
x = v0 * x
|
||||
return x
|
||||
|
||||
@function.defun
|
||||
def f(x):
|
||||
x = g(v0 * x)
|
||||
return x
|
||||
|
||||
x = constant_op.constant(3.0)
|
||||
y = f(x)
|
||||
|
||||
self.assertEqual(75, y.numpy())
|
||||
|
||||
def testNestedDefunInGradientTape(self):
|
||||
with self.test_scope():
|
||||
v0 = resource_variable_ops.ResourceVariable(5.0)
|
||||
|
||||
@function.defun
|
||||
def g(x):
|
||||
x = v0 * x
|
||||
return x
|
||||
|
||||
@function.defun
|
||||
def f(x):
|
||||
x = g(v0 * x)
|
||||
return x
|
||||
|
||||
x = constant_op.constant(3.0)
|
||||
with backprop.GradientTape() as tape:
|
||||
y = f(x)
|
||||
dy = tape.gradient(y, v0)
|
||||
|
||||
self.assertEqual(75, y.numpy())
|
||||
self.assertEqual(30, dy.numpy())
|
||||
|
||||
def testNestedDefunInGradientTapeDifferentVars(self):
|
||||
with self.test_scope():
|
||||
v0 = resource_variable_ops.ResourceVariable(5.0)
|
||||
v1 = resource_variable_ops.ResourceVariable(3.0)
|
||||
|
||||
@function.defun
|
||||
def g(x):
|
||||
x = v1 * x
|
||||
return x
|
||||
|
||||
@function.defun
|
||||
def f(x):
|
||||
x = g(v0 * x)
|
||||
return x
|
||||
|
||||
x = constant_op.constant(3.0)
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
y = f(x)
|
||||
dy_v0 = tape.gradient(y, v0)
|
||||
dy_v1 = tape.gradient(y, v1)
|
||||
|
||||
self.assertEqual(45, y.numpy())
|
||||
self.assertEqual(9, dy_v0.numpy())
|
||||
self.assertEqual(15, dy_v1.numpy())
|
||||
|
||||
|
||||
class ExcessivePaddingTest(xla_test.XLATestCase):
|
||||
"""Test that eager execution works with TPU flattened tensors.
|
||||
|
@ -146,6 +146,7 @@ Status GraphCompiler::Compile() {
|
||||
}
|
||||
|
||||
OpKernelContext op_context(¶ms, n->num_outputs());
|
||||
VLOG(3) << "Translating " << params.op_kernel->name();
|
||||
if (IsFunctional(n)) {
|
||||
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
|
||||
} else {
|
||||
|
@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel {
|
||||
} else {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
DataType input_type = ctx->input_type(0);
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
|
||||
if (input_type == DT_RESOURCE) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
|
||||
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
|
||||
return;
|
||||
}
|
||||
|
||||
auto is_constant = ctx->builder()->IsConstant(input);
|
||||
if (!is_constant.ok()) {
|
||||
@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
if (tc.resolve_compile_time_constants() &&
|
||||
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
|
||||
xla::Literal literal;
|
||||
@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
|
||||
REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
|
||||
RetvalOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -361,6 +361,9 @@ Status BuildComputation(
|
||||
if (retval.has_constant_value()) {
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value();
|
||||
} else if (retval.resource() != nullptr) {
|
||||
output.is_constant = false;
|
||||
output.input_index = retval.resource()->arg_num();
|
||||
} else {
|
||||
output.is_constant = false;
|
||||
elems.push_back(retval.handle());
|
||||
@ -495,7 +498,8 @@ Status XlaCompiler::BuildArguments(
|
||||
arg_expression.set_constant_value(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
return errors::Internal("Unreachable case in BuildArguments()");
|
||||
return errors::Internal(
|
||||
"Unreachable case in BuildArguments() while filling constant args");
|
||||
}
|
||||
}
|
||||
|
||||
@ -615,7 +619,8 @@ Status XlaCompiler::BuildArguments(
|
||||
break;
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
return errors::Internal("Unreachable case in BuildArguments()");
|
||||
return errors::Internal(
|
||||
"Unreachable case in BuildArguments() while filling handles");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -183,6 +183,8 @@ class XlaCompiler {
|
||||
|
||||
struct OutputDescription {
|
||||
// Type and shape of the output. The shape is the unflattened shape.
|
||||
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
|
||||
// variable's value.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
|
||||
@ -190,6 +192,10 @@ class XlaCompiler {
|
||||
// 'Tensor' is in host memory.
|
||||
bool is_constant = false;
|
||||
Tensor constant_value;
|
||||
|
||||
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
|
||||
// the index of the input that contains the resource.
|
||||
int input_index;
|
||||
};
|
||||
|
||||
// Describes a variable write side effect of the computation.
|
||||
|
@ -861,6 +861,33 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
void RunAndCheckVariablesComputation(
|
||||
xla::Client* client, const XlaCompiler::CompilationResult& result) {
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::GlobalData> actual =
|
||||
client
|
||||
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
|
||||
.ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::LiteralUtil::CreateR1<int32>({5, 144});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 143});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
// Tests a simple graph that reads and writes a variable.
|
||||
TEST_F(XlaCompilerTest, Variables) {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
@ -892,36 +919,90 @@ TEST_F(XlaCompilerTest, Variables) {
|
||||
// Compiles the graph.
|
||||
XlaCompiler compiler(DefaultOptions());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
|
||||
std::move(graph), args, &result));
|
||||
RunAndCheckVariablesComputation(client_, result);
|
||||
}
|
||||
|
||||
// Tests a simple graph that reads and writes a variable.
|
||||
TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
|
||||
auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(1);
|
||||
args[0].kind = XlaCompiler::Argument::kResource;
|
||||
args[0].resource_kind = XlaResource::kVariable;
|
||||
args[0].initialized = true;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2});
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler compiler(DefaultOptions());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
|
||||
std::move(graph), args, &result));
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::GlobalData> actual =
|
||||
client_
|
||||
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
|
||||
client_->Execute(*result.computation, {param1_data.get()})
|
||||
.ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::LiteralUtil::CreateR1<int32>({5, 144});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 143});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
xla::LiteralUtil::MakeTuple({});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
|
||||
// Adds an identity op around the resource to make sure identity ops propagate
|
||||
// resources correctly.
|
||||
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
|
||||
auto write = ops::AssignAddVariableOp(scope, identity, a);
|
||||
auto read = ops::ReadVariableOp(
|
||||
scope.WithControlDependencies(std::vector<Operation>{write}), var,
|
||||
DT_INT32);
|
||||
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
|
||||
auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
|
||||
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2});
|
||||
args[1].kind = XlaCompiler::Argument::kResource;
|
||||
args[1].resource_kind = XlaResource::kVariable;
|
||||
args[1].initialized = true;
|
||||
args[1].type = DT_INT32;
|
||||
args[1].shape = TensorShape({2});
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler compiler(DefaultOptions());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
|
||||
std::move(graph), args, &result));
|
||||
RunAndCheckVariablesComputation(client_, result);
|
||||
}
|
||||
|
||||
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
|
@ -107,6 +107,19 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
|
||||
VLOG(1) << "Adding retval index " << retval_index << " with resource "
|
||||
<< resource->name() << ":" << resource->shape().DebugString()
|
||||
<< " to XLA computation";
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
XlaExpression e;
|
||||
e.set_resource(resource);
|
||||
retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::XlaBuilder* XlaContext::builder() { return builder_; }
|
||||
|
||||
Status XlaContext::CreateResource(
|
||||
|
@ -86,6 +86,9 @@ class XlaContext : public ResourceBase {
|
||||
Status AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::LiteralSlice& literal);
|
||||
|
||||
// As for Retval, but for return values that are resource handles.
|
||||
Status AddResourceRetval(int retval_index, XlaResource* resource);
|
||||
|
||||
// Creates a resource with resource `kind` and initial value `handle`. `name`
|
||||
// is a descriptive name for use in error messages. See the `XlaResource`
|
||||
// constructor for a description of the remaining arguments.
|
||||
|
Loading…
Reference in New Issue
Block a user