[TF2XLA] Support must-be-constant resource variables for compilation
Performs an explicit copy at runtime from device to host if needed. PiperOrigin-RevId: 341491694 Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5
This commit is contained in:
parent
aed0ad4916
commit
38c53e2f59
@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
|
|||||||
|
|
||||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||||
constant_arg_indices, inputs, variable_infos);
|
constant_arg_indices, inputs, variable_infos, dev);
|
||||||
TF_RETURN_IF_ERROR(args.status());
|
TF_RETURN_IF_ERROR(args.status());
|
||||||
|
|
||||||
switch (stage) {
|
switch (stage) {
|
||||||
|
@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
|
|||||||
may_alias_resource_update;
|
may_alias_resource_update;
|
||||||
|
|
||||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
|
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||||
variable_infos);
|
constants, inputs, variable_infos,
|
||||||
|
static_cast<Device*>(ctx->device()));
|
||||||
TF_RETURN_IF_ERROR(args.status());
|
TF_RETURN_IF_ERROR(args.status());
|
||||||
return cache->Compile(options, function, *args, compile_options,
|
return cache->Compile(options, function, *args, compile_options,
|
||||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||||
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
|||||||
se::Stream* stream =
|
se::Stream* stream =
|
||||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||||
|
|
||||||
VLOG(1) << "Executing XLA Computation...";
|
|
||||||
|
|
||||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||||
&tf_allocator_adapter, ctx->device(),
|
&tf_allocator_adapter, ctx->device(),
|
||||||
|
@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
|
|||||||
for (const XlaCompiler::Argument& arg : args) {
|
for (const XlaCompiler::Argument& arg : args) {
|
||||||
switch (arg.kind) {
|
switch (arg.kind) {
|
||||||
case XlaCompiler::Argument::kConstant:
|
case XlaCompiler::Argument::kConstant:
|
||||||
|
case XlaCompiler::Argument::kConstantResource:
|
||||||
signature.arg_values.push_back(arg.constant_value);
|
signature.arg_values.push_back(arg.constant_value);
|
||||||
break;
|
break;
|
||||||
case XlaCompiler::Argument::kParameter:
|
case XlaCompiler::Argument::kParameter:
|
||||||
|
@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
|
|||||||
ctx, variables_indices, variable_infos, variable_args));
|
ctx, variables_indices, variable_infos, variable_args));
|
||||||
|
|
||||||
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||||
constant_input_indices, inputs, variable_infos);
|
constant_input_indices, inputs, variable_infos,
|
||||||
|
static_cast<Device*>(ctx->device()));
|
||||||
TF_RETURN_IF_ERROR(args.status());
|
TF_RETURN_IF_ERROR(args.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
|||||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||||
absl::Span<int const> must_be_constant_idxs,
|
absl::Span<int const> must_be_constant_idxs,
|
||||||
absl::Span<const Tensor* const> inputs,
|
absl::Span<const Tensor* const> inputs,
|
||||||
absl::Span<VariableInfo const> variable_args) {
|
absl::Span<VariableInfo const> variable_args, Device* device) {
|
||||||
CHECK(absl::c_is_sorted(must_be_constant_idxs));
|
CHECK(absl::c_is_sorted(must_be_constant_idxs));
|
||||||
std::vector<XlaCompiler::Argument> out;
|
std::vector<XlaCompiler::Argument> out;
|
||||||
out.resize(inputs.size());
|
out.resize(inputs.size());
|
||||||
|
|
||||||
|
// TODO(cheshire): Avoid duplication with framework/op_kernel.h
|
||||||
|
DeviceContext* device_context = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
|
||||||
|
bool using_default_context = false;
|
||||||
|
auto cleanup = xla::MakeCleanup([&] {
|
||||||
|
if (device_context != nullptr && !using_default_context) {
|
||||||
|
device_context->Unref();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (device_context == nullptr) {
|
||||||
|
using_default_context = true;
|
||||||
|
auto* dev_info = device->tensorflow_gpu_device_info();
|
||||||
|
if (dev_info) device_context = dev_info->default_context;
|
||||||
|
}
|
||||||
|
|
||||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||||
for (const VariableInfo& info : variable_args) {
|
for (const VariableInfo& info : variable_args) {
|
||||||
CHECK(!info.var() || info.lock_held())
|
CHECK(!info.var() || info.lock_held())
|
||||||
@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
|||||||
const Tensor* input = inputs[input_num];
|
const Tensor* input = inputs[input_num];
|
||||||
|
|
||||||
XlaCompiler::Argument& arg = out[input_num];
|
XlaCompiler::Argument& arg = out[input_num];
|
||||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
if (variable_info_lookup.count(input_num)) {
|
||||||
// Handles compile-time constants.
|
|
||||||
|
|
||||||
// TODO(b/157241314): Support constants located in resource variables.
|
|
||||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
|
|
||||||
<< "tf2xla bridge does not support must-be-constants located in "
|
|
||||||
"resource variables; try moving them to a tensor";
|
|
||||||
arg.kind = XlaCompiler::Argument::kConstant;
|
|
||||||
arg.type = input->dtype();
|
|
||||||
arg.shape = input->shape();
|
|
||||||
arg.constant_value = *input;
|
|
||||||
} else if (variable_info_lookup.count(input_num)) {
|
|
||||||
// Handles resource variables.
|
// Handles resource variables.
|
||||||
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
|
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
|
||||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||||
@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
|||||||
arg.type = DT_INVALID;
|
arg.type = DT_INVALID;
|
||||||
arg.shape = TensorShape();
|
arg.shape = TensorShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||||
|
TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
|
||||||
|
const Tensor* value = variable.var()->tensor();
|
||||||
|
Tensor value_on_host(value->dtype(), value->shape());
|
||||||
|
if (!device_context) {
|
||||||
|
value_on_host = *value;
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
|
||||||
|
value, "", device, &value_on_host));
|
||||||
|
}
|
||||||
|
arg.kind = XlaCompiler::Argument::kConstantResource;
|
||||||
|
arg.constant_value = value_on_host;
|
||||||
|
}
|
||||||
|
} else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||||
|
arg.kind = XlaCompiler::Argument::kConstant;
|
||||||
|
arg.type = input->dtype();
|
||||||
|
arg.shape = input->shape();
|
||||||
|
arg.constant_value = *input;
|
||||||
} else {
|
} else {
|
||||||
// Normal inputs.
|
// Normal inputs.
|
||||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
||||||
|
@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
|
|||||||
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||||
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
|
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
|
||||||
absl::Span<const Tensor* const> inputs,
|
absl::Span<const Tensor* const> inputs,
|
||||||
absl::Span<VariableInfo const> variable_args);
|
absl::Span<VariableInfo const> variable_args,
|
||||||
|
Device* device);
|
||||||
|
|
||||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||||
// `variables` is a map from TensorFlow argument number to resource variable.
|
// `variables` is a map from TensorFlow argument number to resource variable.
|
||||||
|
@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
|||||||
switch (expressions[i]->kind()) {
|
switch (expressions[i]->kind()) {
|
||||||
case XlaExpression::Kind::kConstant:
|
case XlaExpression::Kind::kConstant:
|
||||||
arg.kind = XlaCompiler::Argument::kConstant;
|
arg.kind = XlaCompiler::Argument::kConstant;
|
||||||
arg.constant_value = expressions[i]->constant_value();
|
arg.constant_value = *expressions[i]->constant_value();
|
||||||
break;
|
break;
|
||||||
case XlaExpression::Kind::kXlaOp:
|
case XlaExpression::Kind::kXlaOp:
|
||||||
if (arg_must_be_compile_time_constant[i]) {
|
if (arg_must_be_compile_time_constant[i]) {
|
||||||
|
@ -39,6 +39,9 @@ struct XlaArgument {
|
|||||||
// associated runtime parameter iff `initialized` is true.
|
// associated runtime parameter iff `initialized` is true.
|
||||||
kResource,
|
kResource,
|
||||||
|
|
||||||
|
// A resource variable with a constant value known at compile time.
|
||||||
|
kConstantResource,
|
||||||
|
|
||||||
// Argument is a run-time parameter.
|
// Argument is a run-time parameter.
|
||||||
kParameter,
|
kParameter,
|
||||||
|
|
||||||
|
@ -207,7 +207,7 @@ Status BuildComputation(
|
|||||||
switch (retval.kind()) {
|
switch (retval.kind()) {
|
||||||
case XlaExpression::Kind::kConstant:
|
case XlaExpression::Kind::kConstant:
|
||||||
output.is_constant = true;
|
output.is_constant = true;
|
||||||
output.constant_value = retval.constant_value();
|
output.constant_value = *retval.constant_value();
|
||||||
output.shape = output.constant_value.shape();
|
output.shape = output.constant_value.shape();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@ -446,6 +446,9 @@ string XlaCompiler::Argument::HumanString() const {
|
|||||||
case kConstant:
|
case kConstant:
|
||||||
return absl::StrCat("kind=constant", common,
|
return absl::StrCat("kind=constant", common,
|
||||||
" value=", constant_value.DebugString());
|
" value=", constant_value.DebugString());
|
||||||
|
case kConstantResource:
|
||||||
|
return absl::StrCat("kind=constant-resource", common,
|
||||||
|
" value=", constant_value.DebugString());
|
||||||
case kResource: {
|
case kResource: {
|
||||||
string output = absl::StrCat(
|
string output = absl::StrCat(
|
||||||
"kind=resource", common,
|
"kind=resource", common,
|
||||||
@ -856,6 +859,7 @@ Status XlaCompiler::XLAShapeForArgument(
|
|||||||
*xla_shape = absl::get<xla::Shape>(arg.shape);
|
*xla_shape = absl::get<xla::Shape>(arg.shape);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
case XlaCompiler::Argument::kConstantResource:
|
||||||
case XlaCompiler::Argument::kResource: {
|
case XlaCompiler::Argument::kResource: {
|
||||||
TF_RET_CHECK(arg.initialized);
|
TF_RET_CHECK(arg.initialized);
|
||||||
|
|
||||||
@ -959,6 +963,7 @@ Status XlaCompiler::BuildArguments(
|
|||||||
const XlaCompiler::Argument& arg = args[i];
|
const XlaCompiler::Argument& arg = args[i];
|
||||||
XlaExpression& arg_expression = (*arg_expressions)[i];
|
XlaExpression& arg_expression = (*arg_expressions)[i];
|
||||||
switch (arg.kind) {
|
switch (arg.kind) {
|
||||||
|
case XlaCompiler::Argument::kConstantResource:
|
||||||
case XlaCompiler::Argument::kResource: {
|
case XlaCompiler::Argument::kResource: {
|
||||||
TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
|
TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
|
||||||
TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
|
TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
|
||||||
@ -971,7 +976,10 @@ Status XlaCompiler::BuildArguments(
|
|||||||
/*max_array_size=*/arg.max_array_size,
|
/*max_array_size=*/arg.max_array_size,
|
||||||
/*tensor_array_gradients=*/arg.tensor_array_gradients,
|
/*tensor_array_gradients=*/arg.tensor_array_gradients,
|
||||||
/*tensor_array_multiple_writes_aggregate=*/true));
|
/*tensor_array_multiple_writes_aggregate=*/true));
|
||||||
arg_expression = XlaExpression::Resource(resource);
|
arg_expression =
|
||||||
|
arg.kind == XlaCompiler::Argument::kResource
|
||||||
|
? XlaExpression::Resource(resource)
|
||||||
|
: XlaExpression::ConstantResource(arg.constant_value, resource);
|
||||||
if (arg.initialized) {
|
if (arg.initialized) {
|
||||||
input_to_args->push_back(i);
|
input_to_args->push_back(i);
|
||||||
}
|
}
|
||||||
@ -1124,6 +1132,7 @@ Status XlaCompiler::BuildArguments(
|
|||||||
arg_shardings.at(i).DebugString()));
|
arg_shardings.at(i).DebugString()));
|
||||||
XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
|
XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
|
||||||
switch (arg.kind) {
|
switch (arg.kind) {
|
||||||
|
case XlaCompiler::Argument::kConstantResource:
|
||||||
case XlaCompiler::Argument::kResource: {
|
case XlaCompiler::Argument::kResource: {
|
||||||
TF_RET_CHECK(arg.initialized);
|
TF_RET_CHECK(arg.initialized);
|
||||||
XlaResource* resource = arg_expression.resource();
|
XlaResource* resource = arg_expression.resource();
|
||||||
|
@ -38,6 +38,16 @@ XlaExpression XlaExpression::Constant(Tensor value) {
|
|||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaExpression XlaExpression::ConstantResource(Tensor value,
|
||||||
|
XlaResource* resource) {
|
||||||
|
XlaExpression e;
|
||||||
|
e.kind_ = Kind::kResource;
|
||||||
|
e.dtype_ = DT_RESOURCE;
|
||||||
|
e.resource_ = resource;
|
||||||
|
e.constant_value_ = value;
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
|
||||||
XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
|
XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
|
||||||
XlaExpression e;
|
XlaExpression e;
|
||||||
e.kind_ = Kind::kXlaOp;
|
e.kind_ = Kind::kXlaOp;
|
||||||
@ -83,7 +93,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
|
|||||||
case Kind::kConstant: {
|
case Kind::kConstant: {
|
||||||
xla::BorrowingLiteral literal;
|
xla::BorrowingLiteral literal;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
HostTensorToBorrowingLiteral(constant_value_, &literal));
|
HostTensorToBorrowingLiteral(*constant_value_, &literal));
|
||||||
return xla::ConstantLiteral(builder, literal);
|
return xla::ConstantLiteral(builder, literal);
|
||||||
}
|
}
|
||||||
case Kind::kTensorList:
|
case Kind::kTensorList:
|
||||||
@ -106,7 +116,7 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
|
|||||||
switch (kind()) {
|
switch (kind()) {
|
||||||
case Kind::kConstant: {
|
case Kind::kConstant: {
|
||||||
// Constant values are considered static.
|
// Constant values are considered static.
|
||||||
Tensor constant_false(DT_BOOL, constant_value().shape());
|
Tensor constant_false(DT_BOOL, constant_value()->shape());
|
||||||
auto flat = constant_false.flat<bool>();
|
auto flat = constant_false.flat<bool>();
|
||||||
for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
|
for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
|
||||||
return constant_false;
|
return constant_false;
|
||||||
@ -147,13 +157,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
|
|||||||
xla::Client* client, bool dynamic_dimension_is_minus_one) const {
|
xla::Client* client, bool dynamic_dimension_is_minus_one) const {
|
||||||
switch (kind()) {
|
switch (kind()) {
|
||||||
case Kind::kConstant:
|
case Kind::kConstant:
|
||||||
return {constant_value()};
|
case Kind::kResource:
|
||||||
|
return constant_value();
|
||||||
case Kind::kXlaOp:
|
case Kind::kXlaOp:
|
||||||
break;
|
break;
|
||||||
case Kind::kTensorList:
|
case Kind::kTensorList:
|
||||||
TF_FALLTHROUGH_INTENDED;
|
TF_FALLTHROUGH_INTENDED;
|
||||||
case Kind::kResource:
|
|
||||||
TF_FALLTHROUGH_INTENDED;
|
|
||||||
case Kind::kInvalid:
|
case Kind::kInvalid:
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"ResolveConstant called on XlaExpression: ", HumanString());
|
"ResolveConstant called on XlaExpression: ", HumanString());
|
||||||
@ -187,7 +196,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
|
|||||||
xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
|
xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
|
||||||
switch (kind_) {
|
switch (kind_) {
|
||||||
case Kind::kConstant:
|
case Kind::kConstant:
|
||||||
return constant_value().shape();
|
return constant_value()->shape();
|
||||||
|
case Kind::kResource:
|
||||||
|
if (constant_value()) {
|
||||||
|
return constant_value()->shape();
|
||||||
|
}
|
||||||
|
return TensorShape({});
|
||||||
case Kind::kXlaOp: {
|
case Kind::kXlaOp: {
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
|
TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
|
||||||
handle().builder()->GetShape(handle()));
|
handle().builder()->GetShape(handle()));
|
||||||
@ -197,8 +211,6 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
|
|||||||
}
|
}
|
||||||
case Kind::kTensorList:
|
case Kind::kTensorList:
|
||||||
return TensorShape({});
|
return TensorShape({});
|
||||||
case Kind::kResource:
|
|
||||||
return TensorShape({});
|
|
||||||
case Kind::kInvalid:
|
case Kind::kInvalid:
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"GetShape() called on invalid XlaExpression");
|
"GetShape() called on invalid XlaExpression");
|
||||||
|
@ -74,6 +74,9 @@ class XlaExpression {
|
|||||||
// Builds a resource expression.
|
// Builds a resource expression.
|
||||||
static XlaExpression Resource(XlaResource* resource);
|
static XlaExpression Resource(XlaResource* resource);
|
||||||
|
|
||||||
|
// Builds a resource whose value is known at a compile time.
|
||||||
|
static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
|
||||||
|
|
||||||
Kind kind() const { return kind_; }
|
Kind kind() const { return kind_; }
|
||||||
|
|
||||||
DataType dtype() const { return dtype_; }
|
DataType dtype() const { return dtype_; }
|
||||||
@ -81,7 +84,15 @@ class XlaExpression {
|
|||||||
// handle() returns the XlaOp that backs a kXlaOp expression.
|
// handle() returns the XlaOp that backs a kXlaOp expression.
|
||||||
const xla::XlaOp& handle() const { return handle_; }
|
const xla::XlaOp& handle() const { return handle_; }
|
||||||
|
|
||||||
const Tensor& constant_value() const { return constant_value_; }
|
// Return a constant value associated with this expression. Always set for
|
||||||
|
// constants, might be set for resources.
|
||||||
|
absl::optional<Tensor> constant_value() const {
|
||||||
|
if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
|
||||||
|
// The constant is no longer available if the value was overwritten.
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
return constant_value_;
|
||||||
|
}
|
||||||
|
|
||||||
XlaResource* resource() const { return resource_; }
|
XlaResource* resource() const { return resource_; }
|
||||||
|
|
||||||
@ -124,8 +135,8 @@ class XlaExpression {
|
|||||||
// a tuple expression if kind_ == kTensorList.
|
// a tuple expression if kind_ == kTensorList.
|
||||||
xla::XlaOp handle_;
|
xla::XlaOp handle_;
|
||||||
|
|
||||||
// The value of the constant, if kind_ == kConstant.
|
// The value of the constant, if available.
|
||||||
Tensor constant_value_;
|
absl::optional<Tensor> constant_value_;
|
||||||
|
|
||||||
// The resource, if kind_ == kResource. Not owned.
|
// The resource, if kind_ == kResource. Not owned.
|
||||||
XlaResource* resource_ = nullptr;
|
XlaResource* resource_ = nullptr;
|
||||||
|
@ -110,8 +110,10 @@ TEST_F(XlaExpressionTest, GetShape) {
|
|||||||
TEST_F(XlaExpressionTest, ResolveConstant) {
|
TEST_F(XlaExpressionTest, ResolveConstant) {
|
||||||
EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
|
EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
|
||||||
EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
|
EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
|
||||||
EXPECT_FALSE(
|
|
||||||
XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
|
EXPECT_FALSE(XlaExpression::Resource(resource_.get())
|
||||||
|
.ResolveConstant(client_)
|
||||||
|
->has_value());
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
absl::optional<Tensor> op_constant,
|
absl::optional<Tensor> op_constant,
|
||||||
@ -131,5 +133,17 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
|
|||||||
test::ExpectTensorEqual<int32>(constant_, *constant_constant);
|
test::ExpectTensorEqual<int32>(constant_, *constant_constant);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
|
||||||
|
XlaExpression constant_resource =
|
||||||
|
XlaExpression::ConstantResource(constant_, resource_.get());
|
||||||
|
EXPECT_TRUE(constant_resource.ResolveConstant(client_).ok());
|
||||||
|
EXPECT_TRUE(resource_->SetZeroValue(builder_.get()).ok());
|
||||||
|
LOG(ERROR) << "Resource is overwritten: " << resource_->IsOverwritten();
|
||||||
|
xla::StatusOr<absl::optional<Tensor>> resolved_constant =
|
||||||
|
constant_resource.ResolveConstant(client_);
|
||||||
|
EXPECT_TRUE(resolved_constant.ok());
|
||||||
|
EXPECT_FALSE(resolved_constant->has_value());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -477,6 +477,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
|
|||||||
*shape = variable->shape();
|
*shape = variable->shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!variable->IsOverwritten() && expression->constant_value()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(xla::Literal literal,
|
||||||
|
HostTensorToLiteral(*expression->constant_value()));
|
||||||
|
*value = xla::ConstantLiteral(ctx->builder(), literal);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
|
TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
|
||||||
ctx->compiler()->options().shape_representation_fn(
|
ctx->compiler()->options().shape_representation_fn(
|
||||||
variable->shape(), variable->type(),
|
variable->shape(), variable->type(),
|
||||||
|
@ -116,10 +116,12 @@ Status XlaResource::SetValue(const xla::XlaOp& value) {
|
|||||||
"' must be initialized with a valid type before use.");
|
"' must be initialized with a valid type before use.");
|
||||||
}
|
}
|
||||||
value_ = value;
|
value_ = value;
|
||||||
|
is_overwritten_ = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
|
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
|
||||||
|
is_overwritten_ = true;
|
||||||
if (type_ == DT_INVALID) {
|
if (type_ == DT_INVALID) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Resource '", name_,
|
"Resource '", name_,
|
||||||
|
@ -135,6 +135,8 @@ class XlaResource {
|
|||||||
Status SetFromPack(const std::set<string>& gradient_sources,
|
Status SetFromPack(const std::set<string>& gradient_sources,
|
||||||
const xla::XlaOp& pack, xla::XlaBuilder* builder);
|
const xla::XlaOp& pack, xla::XlaBuilder* builder);
|
||||||
|
|
||||||
|
bool IsOverwritten() { return is_overwritten_; }
|
||||||
|
|
||||||
// TensorArray and Stack specific fields
|
// TensorArray and Stack specific fields
|
||||||
// TODO(phawkins): refactor this code to use subclasses, rather than putting
|
// TODO(phawkins): refactor this code to use subclasses, rather than putting
|
||||||
// kind-specific fields in XlaResource.
|
// kind-specific fields in XlaResource.
|
||||||
@ -179,6 +181,7 @@ class XlaResource {
|
|||||||
bool tensor_array_multiple_writes_aggregate_ = false;
|
bool tensor_array_multiple_writes_aggregate_ = false;
|
||||||
|
|
||||||
std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
|
std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
|
||||||
|
bool is_overwritten_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -656,6 +656,77 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
self.assertIn('tuple',
|
self.assertIn('tuple',
|
||||||
f.experimental_get_compiler_ir(l)())
|
f.experimental_get_compiler_ir(l)())
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
|
||||||
|
'support getting constants out of resources')
|
||||||
|
def testGetConstantOutOfResourceVariable(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
# Use floats to force device placement.
|
||||||
|
a = variables.Variable(50.0)
|
||||||
|
b = variables.Variable(2.0)
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def f(x):
|
||||||
|
return array_ops.reshape(
|
||||||
|
x, [math_ops.cast(a, dtypes.int32),
|
||||||
|
math_ops.cast(b, dtypes.int32)])
|
||||||
|
|
||||||
|
# OK since the value is known at compile time.
|
||||||
|
out = f(random_ops.random_normal([10, 10]))
|
||||||
|
self.assertEqual(out.shape[0], 50)
|
||||||
|
self.assertEqual(out.shape[1], 2)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
|
||||||
|
'support getting constants out of resources')
|
||||||
|
def testGetConstantOutOfResourceVariableAfterWrite(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
# Use floats to force device placement.
|
||||||
|
a = variables.Variable(50.0)
|
||||||
|
b = variables.Variable(2.0)
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def f(x, val1, val2):
|
||||||
|
a.assign(math_ops.cast(val1, dtypes.float32))
|
||||||
|
b.assign(math_ops.cast(val2, dtypes.float32))
|
||||||
|
return array_ops.reshape(
|
||||||
|
x, [math_ops.cast(a, dtypes.int32),
|
||||||
|
math_ops.cast(b, dtypes.int32)])
|
||||||
|
|
||||||
|
val1 = constant_op.constant(2)
|
||||||
|
val2 = constant_op.constant(50)
|
||||||
|
|
||||||
|
# Returns an error, since the value known at compile time was overriden.
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
'concrete values at compile time'):
|
||||||
|
f(random_ops.random_normal([10, 10]), val1, val2)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
|
||||||
|
'support getting constants out of resources')
|
||||||
|
def testGetConstantOutOfResourceVariableBeforeWrite(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
# Use floats to force device placement.
|
||||||
|
a = variables.Variable(50.0)
|
||||||
|
b = variables.Variable(2.0)
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def f(x, val1, val2):
|
||||||
|
out = array_ops.reshape(
|
||||||
|
x, [math_ops.cast(a, dtypes.int32),
|
||||||
|
math_ops.cast(b, dtypes.int32)])
|
||||||
|
a.assign(math_ops.cast(val1, dtypes.float32))
|
||||||
|
b.assign(math_ops.cast(val2, dtypes.float32))
|
||||||
|
return out
|
||||||
|
|
||||||
|
val1 = constant_op.constant(2)
|
||||||
|
val2 = constant_op.constant(50)
|
||||||
|
|
||||||
|
# OK since the write happens after the reshape.
|
||||||
|
out = f(random_ops.random_normal([10, 10]), val1, val2)
|
||||||
|
self.assertEqual(out.shape[0], 50)
|
||||||
|
self.assertEqual(out.shape[1], 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user