[TF2XLA] Support must-be-constant resource variables for compilation

Performs an explicit copy at runtime from device to host if needed.

PiperOrigin-RevId: 341491694
Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5
Author: George Karpenkov, 2020-11-09 14:57:06 -08:00 (committed by TensorFlower Gardener)
parent aed0ad4916
commit 38c53e2f59
16 changed files with 193 additions and 36 deletions
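
For readers who want the new behavior end to end, here is a minimal usage sketch (not part of the commit; it assumes a recent TensorFlow build in which tf.function accepts jit_compile, and it mirrors the new Python test added at the end of this diff). With this change, a jit-compiled function can feed the value of a resource variable into an operand that XLA requires to be a compile-time constant, such as the target shape of a reshape; the value is read out of the variable, copied device-to-host if necessary, when the computation is compiled.

import tensorflow as tf

# Float variables to force placement on the accelerator device, as in the test.
rows = tf.Variable(50.0)
cols = tf.Variable(2.0)

@tf.function(jit_compile=True)
def reshape_by_vars(x):
  # The casts feed the shape operand of tf.reshape, which XLA treats as a
  # must-be-constant input; its value now comes from the resource variables.
  return tf.reshape(x, [tf.cast(rows, tf.int32), tf.cast(cols, tf.int32)])

out = reshape_by_vars(tf.random.normal([10, 10]))
print(out.shape)  # (50, 2)

Assigning to rows or cols inside the function before the reshape invalidates the captured constant and raises InvalidArgumentError, which is what the new testGetConstantOutOfResourceVariableAfterWrite below exercises.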

View File

@@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
       XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constant_arg_indices, inputs, variable_infos);
+          constant_arg_indices, inputs, variable_infos, dev);
   TF_RETURN_IF_ERROR(args.status());
   switch (stage) {

View File

@@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
       may_alias_resource_update;
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
-      XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
-                                                             variable_infos);
+      XlaComputationLaunchContext::BuildXlaCompilerArguments(
+          constants, inputs, variable_infos,
+          static_cast<Device*>(ctx->device()));
   TF_RETURN_IF_ERROR(args.status());
   return cache->Compile(options, function, *args, compile_options,
                         lazy ? XlaCompilationCache::CompileMode::kLazy
@@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),

View File

@@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
   for (const XlaCompiler::Argument& arg : args) {
     switch (arg.kind) {
       case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kConstantResource:
         signature.arg_values.push_back(arg.constant_value);
         break;
       case XlaCompiler::Argument::kParameter:

View File

@@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
         ctx, variables_indices, variable_infos, variable_args));
     args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
-        constant_input_indices, inputs, variable_infos);
+        constant_input_indices, inputs, variable_infos,
+        static_cast<Device*>(ctx->device()));
     TF_RETURN_IF_ERROR(args.status());
   }

View File

@@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
 XlaComputationLaunchContext::BuildXlaCompilerArguments(
     absl::Span<int const> must_be_constant_idxs,
     absl::Span<const Tensor* const> inputs,
-    absl::Span<VariableInfo const> variable_args) {
+    absl::Span<VariableInfo const> variable_args, Device* device) {
   CHECK(absl::c_is_sorted(must_be_constant_idxs));
   std::vector<XlaCompiler::Argument> out;
   out.resize(inputs.size());
+
+  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
+  DeviceContext* device_context = nullptr;
+  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
+  bool using_default_context = false;
+  auto cleanup = xla::MakeCleanup([&] {
+    if (device_context != nullptr && !using_default_context) {
+      device_context->Unref();
+    }
+  });
+  if (device_context == nullptr) {
+    using_default_context = true;
+    auto* dev_info = device->tensorflow_gpu_device_info();
+    if (dev_info) device_context = dev_info->default_context;
+  }
+
   absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
   for (const VariableInfo& info : variable_args) {
     CHECK(!info.var() || info.lock_held())
@@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
     const Tensor* input = inputs[input_num];
     XlaCompiler::Argument& arg = out[input_num];
-    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
-      // Handles compile-time constants.
-      // TODO(b/157241314): Support constants located in resource variables.
-      TF_RET_CHECK(input->dtype() != DT_RESOURCE)
-          << "tf2xla bridge does not support must-be-constants located in "
-             "resource variables; try moving them to a tensor";
-      arg.kind = XlaCompiler::Argument::kConstant;
-      arg.type = input->dtype();
-      arg.shape = input->shape();
-      arg.constant_value = *input;
-    } else if (variable_info_lookup.count(input_num)) {
+    if (variable_info_lookup.count(input_num)) {
       // Handles resource variables.
       TF_RET_CHECK(input->dtype() == DT_RESOURCE);
       const VariableInfo& variable = *variable_info_lookup[input_num];
@@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
         arg.type = DT_INVALID;
         arg.shape = TensorShape();
       }
+
+      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
+        const Tensor* value = variable.var()->tensor();
+        Tensor value_on_host(value->dtype(), value->shape());
+        if (!device_context) {
+          value_on_host = *value;
+        } else {
+          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
+              value, "", device, &value_on_host));
+        }
+        arg.kind = XlaCompiler::Argument::kConstantResource;
+        arg.constant_value = value_on_host;
+      }
+    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+      arg.kind = XlaCompiler::Argument::kConstant;
+      arg.type = input->dtype();
+      arg.shape = input->shape();
+      arg.constant_value = *input;
     } else {
       // Normal inputs.
       TF_RET_CHECK(input->dtype() != DT_RESOURCE);

View File

@@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
   static xla::StatusOr<std::vector<XlaCompiler::Argument>>
   BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
                             absl::Span<const Tensor* const> inputs,
-                            absl::Span<VariableInfo const> variable_args);
+                            absl::Span<VariableInfo const> variable_args,
+                            Device* device);

   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.

View File

@@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
     switch (expressions[i]->kind()) {
       case XlaExpression::Kind::kConstant:
         arg.kind = XlaCompiler::Argument::kConstant;
-        arg.constant_value = expressions[i]->constant_value();
+        arg.constant_value = *expressions[i]->constant_value();
         break;
       case XlaExpression::Kind::kXlaOp:
         if (arg_must_be_compile_time_constant[i]) {

View File

@@ -39,6 +39,9 @@ struct XlaArgument {
     // associated runtime parameter iff `initialized` is true.
     kResource,

+    // A resource variable with a constant value known at compile time.
+    kConstantResource,
+
     // Argument is a run-time parameter.
     kParameter,

View File

@@ -207,7 +207,7 @@ Status BuildComputation(
     switch (retval.kind()) {
       case XlaExpression::Kind::kConstant:
         output.is_constant = true;
-        output.constant_value = retval.constant_value();
+        output.constant_value = *retval.constant_value();
         output.shape = output.constant_value.shape();
         break;
@@ -446,6 +446,9 @@ string XlaCompiler::Argument::HumanString() const {
     case kConstant:
       return absl::StrCat("kind=constant", common,
                           " value=", constant_value.DebugString());
+    case kConstantResource:
+      return absl::StrCat("kind=constant-resource", common,
+                          " value=", constant_value.DebugString());
     case kResource: {
       string output = absl::StrCat(
           "kind=resource", common,
@@ -856,6 +859,7 @@ Status XlaCompiler::XLAShapeForArgument(
       *xla_shape = absl::get<xla::Shape>(arg.shape);
       return Status::OK();
     }
+    case XlaCompiler::Argument::kConstantResource:
     case XlaCompiler::Argument::kResource: {
       TF_RET_CHECK(arg.initialized);
@@ -959,6 +963,7 @@ Status XlaCompiler::BuildArguments(
     const XlaCompiler::Argument& arg = args[i];
     XlaExpression& arg_expression = (*arg_expressions)[i];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
@@ -971,7 +976,10 @@
             /*max_array_size=*/arg.max_array_size,
             /*tensor_array_gradients=*/arg.tensor_array_gradients,
             /*tensor_array_multiple_writes_aggregate=*/true));
-        arg_expression = XlaExpression::Resource(resource);
+        arg_expression =
+            arg.kind == XlaCompiler::Argument::kResource
+                ? XlaExpression::Resource(resource)
+                : XlaExpression::ConstantResource(arg.constant_value, resource);
         if (arg.initialized) {
           input_to_args->push_back(i);
         }
@@ -1124,6 +1132,7 @@
                                        arg_shardings.at(i).DebugString()));
     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.initialized);
         XlaResource* resource = arg_expression.resource();

View File

@@ -38,6 +38,16 @@ XlaExpression XlaExpression::Constant(Tensor value) {
   return e;
 }

+XlaExpression XlaExpression::ConstantResource(Tensor value,
+                                              XlaResource* resource) {
+  XlaExpression e;
+  e.kind_ = Kind::kResource;
+  e.dtype_ = DT_RESOURCE;
+  e.resource_ = resource;
+  e.constant_value_ = value;
+  return e;
+}
+
 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
   XlaExpression e;
   e.kind_ = Kind::kXlaOp;
@@ -83,7 +93,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
     case Kind::kConstant: {
       xla::BorrowingLiteral literal;
       TF_RETURN_IF_ERROR(
-          HostTensorToBorrowingLiteral(constant_value_, &literal));
+          HostTensorToBorrowingLiteral(*constant_value_, &literal));
       return xla::ConstantLiteral(builder, literal);
     }
     case Kind::kTensorList:
@@ -106,7 +116,7 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
   switch (kind()) {
     case Kind::kConstant: {
       // Constant values are considered static.
-      Tensor constant_false(DT_BOOL, constant_value().shape());
+      Tensor constant_false(DT_BOOL, constant_value()->shape());
       auto flat = constant_false.flat<bool>();
       for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
       return constant_false;
@@ -147,13 +157,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
     xla::Client* client, bool dynamic_dimension_is_minus_one) const {
   switch (kind()) {
     case Kind::kConstant:
-      return {constant_value()};
+    case Kind::kResource:
+      return constant_value();
     case Kind::kXlaOp:
       break;
     case Kind::kTensorList:
       TF_FALLTHROUGH_INTENDED;
-    case Kind::kResource:
-      TF_FALLTHROUGH_INTENDED;
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "ResolveConstant called on XlaExpression: ", HumanString());
@@ -187,7 +196,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
 xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
   switch (kind_) {
     case Kind::kConstant:
-      return constant_value().shape();
+      return constant_value()->shape();
+    case Kind::kResource:
+      if (constant_value()) {
+        return constant_value()->shape();
+      }
+      return TensorShape({});
     case Kind::kXlaOp: {
       TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
                           handle().builder()->GetShape(handle()));
@@ -197,8 +211,6 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
     }
     case Kind::kTensorList:
       return TensorShape({});
-    case Kind::kResource:
-      return TensorShape({});
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "GetShape() called on invalid XlaExpression");

View File

@@ -74,6 +74,9 @@ class XlaExpression {
   // Builds a resource expression.
   static XlaExpression Resource(XlaResource* resource);

+  // Builds a resource whose value is known at a compile time.
+  static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
+
   Kind kind() const { return kind_; }

   DataType dtype() const { return dtype_; }
@@ -81,7 +84,15 @@
   // handle() returns the XlaOp that backs a kXlaOp expression.
   const xla::XlaOp& handle() const { return handle_; }

-  const Tensor& constant_value() const { return constant_value_; }
+  // Return a constant value associated with this expression. Always set for
+  // constants, might be set for resources.
+  absl::optional<Tensor> constant_value() const {
+    if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
+      return absl::nullopt;
+    }
+    return constant_value_;
+  }

   XlaResource* resource() const { return resource_; }
@@ -124,8 +135,8 @@
   // a tuple expression if kind_ == kTensorList.
   xla::XlaOp handle_;

-  // The value of the constant, if kind_ == kConstant.
-  Tensor constant_value_;
+  // The value of the constant, if available.
+  absl::optional<Tensor> constant_value_;

   // The resource, if kind_ == kResource. Not owned.
   XlaResource* resource_ = nullptr;

View File

@@ -110,8 +110,10 @@ TEST_F(XlaExpressionTest, GetShape) {
 TEST_F(XlaExpressionTest, ResolveConstant) {
   EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
   EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
-  EXPECT_FALSE(
-      XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
+  EXPECT_FALSE(XlaExpression::Resource(resource_.get())
+                   .ResolveConstant(client_)
+                   ->has_value());

   TF_ASSERT_OK_AND_ASSIGN(
       absl::optional<Tensor> op_constant,
@@ -131,5 +133,17 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
   test::ExpectTensorEqual<int32>(constant_, *constant_constant);
 }

+TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
+  XlaExpression constant_resource =
+      XlaExpression::ConstantResource(constant_, resource_.get());
+  EXPECT_TRUE(constant_resource.ResolveConstant(client_).ok());
+  EXPECT_TRUE(resource_->SetZeroValue(builder_.get()).ok());
+  LOG(ERROR) << "Resource is overwritten: " << resource_->IsOverwritten();
+  xla::StatusOr<absl::optional<Tensor>> resolved_constant =
+      constant_resource.ResolveConstant(client_);
+  EXPECT_TRUE(resolved_constant.ok());
+  EXPECT_FALSE(resolved_constant->has_value());
+}
+
 }  // namespace
 }  // namespace tensorflow

View File

@@ -477,6 +477,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
     *shape = variable->shape();
   }

+  if (!variable->IsOverwritten() && expression->constant_value()) {
+    TF_ASSIGN_OR_RETURN(xla::Literal literal,
+                        HostTensorToLiteral(*expression->constant_value()));
+    *value = xla::ConstantLiteral(ctx->builder(), literal);
+    return Status::OK();
+  }
+
   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                       ctx->compiler()->options().shape_representation_fn(
                           variable->shape(), variable->type(),

View File

@@ -116,10 +116,12 @@ Status XlaResource::SetValue(const xla::XlaOp& value) {
         "' must be initialized with a valid type before use.");
   }
   value_ = value;
+  is_overwritten_ = true;
   return Status::OK();
 }

 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
+  is_overwritten_ = true;
   if (type_ == DT_INVALID) {
     return errors::InvalidArgument(
         "Resource '", name_,

View File

@@ -135,6 +135,8 @@ class XlaResource {
   Status SetFromPack(const std::set<string>& gradient_sources,
                      const xla::XlaOp& pack, xla::XlaBuilder* builder);

+  bool IsOverwritten() { return is_overwritten_; }
+
   // TensorArray and Stack specific fields
   // TODO(phawkins): refactor this code to use subclasses, rather than putting
   // kind-specific fields in XlaResource.
@@ -179,6 +181,7 @@
   bool tensor_array_multiple_writes_aggregate_ = false;

   std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
+  bool is_overwritten_ = false;
 };

 }  // namespace tensorflow
} // namespace tensorflow

View File

@@ -656,6 +656,77 @@ class DefFunctionTest(xla_test.XLATestCase):
       self.assertIn('tuple',
                     f.experimental_get_compiler_ir(l)())

+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariable(self):
+    with ops.device('device:{}:0'.format(self.device)):
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      # OK since the value is known at compile time.
+      out = f(random_ops.random_normal([10, 10]))
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableAfterWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # Returns an error, since the value known at compile time was overriden.
+      with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                  'concrete values at compile time'):
+        f(random_ops.random_normal([10, 10]), val1, val2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableBeforeWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        out = array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return out
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # OK since the write happens after the reshape.
+      out = f(random_ops.random_normal([10, 10]), val1, val2)
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+

 if __name__ == '__main__':
   ops.enable_eager_execution()