[TF:XLA] Cleanups to handling of arguments during XLA compilation:

* combine resource kinds in XlaCompiler::Argument::Kind, use a separate XlaResource::Kind field to distinguish different kinds of resource.
* merge XlaContext::HandleOrConstant and XlaExpression, which were almost identical.
* remove XlaContext::Argument; instead, build XlaExpressions directly from XlaCompiler and add them to the XlaContext.

PiperOrigin-RevId: 168439341
Author: Peter Hawkins
Date:   2017-09-12 13:56:34 -07:00
Committed by: TensorFlower Gardener
Parent: 7f5346a809
Commit: 8f37f30027

8 changed files with 102 additions and 155 deletions
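
To make the first two bullets of the commit message concrete, here is a minimal, self-contained C++ sketch of the consolidated argument representation. Only the names `kind`, `resource_kind`, and the enumerators mirror the diffs below; the stand-in structs and `main` are illustrative assumptions, not part of the TensorFlow sources.

    #include <string>

    // Simplified stand-ins for the real TensorFlow types.
    struct XlaResource {
      enum Kind { kInvalid, kVariable, kTensorArray, kStack };
    };

    struct Argument {
      // Before this change, kVariable, kTensorArray, and kStack were
      // separate Kinds. After it, there is a single kResource kind...
      enum Kind { kInvalid, kConstant, kParameter, kResource };
      Kind kind = kInvalid;
      // ...and the flavor of resource lives in a separate field.
      XlaResource::Kind resource_kind = XlaResource::kInvalid;
      std::string name;
    };

    int main() {
      Argument arg;
      // Old call sites chose among three resource Kinds here; new call
      // sites always use kResource and set resource_kind on the side.
      arg.kind = Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      return 0;
    }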


@@ -184,7 +184,8 @@ Status BuildArguments(int num_constant_args,
     XlaCompiler::Argument& arg = (*args)[input_num];
     arg.name = variable_args[variable_id].name;
-    arg.kind = XlaCompiler::Argument::kVariable;
+    arg.kind = XlaCompiler::Argument::kResource;
+    arg.resource_kind = XlaResource::kVariable;
     if (variable_args[variable_id].present) {
       const Tensor& value = variable_args[variable_id].value;
       arg.type = value.dtype();


@@ -21,14 +21,13 @@ limitations under the License.
 #include "tensorflow/core/framework/kernel_def_builder.h"
 
 namespace tensorflow {
 namespace {
 
 // This OpKernel implements the _Arg Op for XLA JIT devices. It
 // associates its output with one of the arguments to a
 // subcomputation.
-class ArgOp : public XlaOpKernel {
+class XlaArgOp : public XlaOpKernel {
  public:
-  explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+  explicit XlaArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
   }
@@ -49,35 +48,13 @@ class ArgOp : public XlaOpKernel {
       return;
     }
-    XlaContext& xc = XlaContext::Get(ctx);
-    const XlaContext::Argument& arg = xc.args()[index_];
-    if (arg.is_resource) {
-      XlaResource::Kind kind;
-      switch (arg.kind) {
-        case XlaCompiler::Argument::kVariable:
-          kind = XlaResource::kVariable;
-          break;
-        case XlaCompiler::Argument::kTensorArray:
-          kind = XlaResource::kTensorArray;
-          break;
-        case XlaCompiler::Argument::kStack:
-          kind = XlaResource::kStack;
-          break;
-        default:
-          CHECK(false);
-      }
-      // TODO(phawkins): this code assumes that variables do not alias.
-      XlaResource* resource;
-      OP_REQUIRES_OK(ctx,
-                     xc.CreateResource(kind, index_, arg.name, arg.value.type,
-                                       arg.value.handle, &resource));
-      resource->tensor_array_size = arg.tensor_array_size;
-      ctx->SetResourceOutput(0, resource);
-    } else if (arg.value.is_constant) {
-      ctx->SetConstantOutput(0, arg.value.constant_value);
+    const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
+    if (arg.resource() != nullptr) {
+      ctx->SetResourceOutput(0, arg.resource());
+    } else if (arg.has_constant_value()) {
+      ctx->SetConstantOutput(0, arg.constant_value());
     } else {
-      ctx->SetOutput(0, arg.value.handle);
+      ctx->SetOutput(0, arg.handle());
     }
   }
@@ -85,10 +62,9 @@ class ArgOp : public XlaOpKernel {
   int index_;
   DataType dtype_;
-  TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
 };
 
-REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), ArgOp);
+REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp);
 
 }  // namespace
 }  // namespace tensorflow


@@ -50,19 +50,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
       TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
       arg.initialized = resource->value.handle() > 0;
-      switch (resource->kind) {
-        case XlaResource::kVariable:
-          arg.kind = XlaCompiler::Argument::kVariable;
-          break;
-        case XlaResource::kTensorArray:
-          arg.kind = XlaCompiler::Argument::kTensorArray;
-          break;
-        case XlaResource::kStack:
-          arg.kind = XlaCompiler::Argument::kStack;
-          break;
-        case XlaResource::kInvalid:
-          CHECK(false);
-      }
+      arg.kind = XlaCompiler::Argument::kResource;
+      arg.resource_kind = resource->kind;
 
       arg.type = resource->type;
       if (arg.initialized) {
         auto shape = ctx->builder()->GetShape(resource->value);

@@ -139,8 +139,6 @@ class XlaExpression {
   Tensor constant_value_;
 
   XlaResource* resource_ = nullptr;  // Not owned.
-
-  TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
 };
 
 }  // namespace tensorflow


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -104,9 +105,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
 XlaCompiler::~XlaCompiler() = default;
 
-int64 XlaCompiler::NextStepId() {
-  return next_step_id_++;
-}
+int64 XlaCompiler::NextStepId() { return next_step_id_++; }
 
 uint64 XlaCompiler::SignatureHash::operator()(
     const std::pair<string, std::vector<Argument>>& signature) const {
@@ -260,10 +259,11 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
 // `args` are the arguments to the computation.
 Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
                       bool use_tuple_arg, xla::ComputationBuilder* builder,
-                      std::vector<XlaContext::Argument>* context_args,
+                      XlaContext* context,
+                      std::vector<XlaExpression>* arg_expressions,
                       std::vector<int>* input_mapping,
                       std::vector<xla::Shape>* input_shapes) {
-  context_args->resize(args.size());
+  arg_expressions->resize(args.size());
 
   // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
@@ -271,33 +271,31 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
   parameters.reserve(args.size());
   resources.reserve(args.size());
 
   // Fills in constant arguments, and computes non-constant argument order.
   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
        ++i) {
-    XlaContext::Argument& context_arg = (*context_args)[i];
-    context_arg.kind = args[i].kind;
-    context_arg.name = args[i].name;
-    context_arg.value.constant_value = args[i].constant_value;
-    context_arg.value.type = args[i].type;
-    switch (args[i].kind) {
-      case XlaCompiler::Argument::kVariable:
-      case XlaCompiler::Argument::kTensorArray:
-      case XlaCompiler::Argument::kStack:
-        context_arg.is_resource = true;
-        if (args[i].initialized) {
+    const XlaCompiler::Argument& arg = args[i];
+    XlaExpression& arg_expression = (*arg_expressions)[i];
+    switch (arg.kind) {
+      case XlaCompiler::Argument::kResource:
+        TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
+        // TODO(phawkins): this code assumes that resource arguments do not
+        // alias.
+        XlaResource* resource;
+        TF_RETURN_IF_ERROR(
+            context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
+                                    xla::ComputationDataHandle(), &resource));
+        resource->tensor_array_size = arg.tensor_array_size;
+        arg_expression.set_resource(resource);
+        if (arg.initialized) {
           resources.push_back(i);
-          context_arg.value.is_constant = false;
-        } else {
-          context_arg.value.is_constant = true;
         }
-        context_arg.tensor_array_size = args[i].tensor_array_size;
         break;
       case XlaCompiler::Argument::kParameter:
         parameters.push_back(i);
-        context_arg.value.is_constant = false;
         break;
       case XlaCompiler::Argument::kConstant:
-        context_arg.value.is_constant = true;
+        arg_expression.set_constant_value(arg.constant_value);
         break;
       case XlaCompiler::Argument::kInvalid:
         return errors::Internal("Unreachable case in BuildArguments()");
@@ -313,27 +311,48 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
   input_shapes->resize(parameters.size());
   input_mapping->resize(parameters.size());
-  for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
+  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
     const XlaCompiler::Argument& arg = args[parameters[i]];
     // Computes the shapes of non-constant arguments.
     (*input_shapes)[i] = arg.shape;
     (*input_mapping)[i] = parameters[i];
   }
 
   // Build parameter handles for non-constant arguments.
+  std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
   if (use_tuple_arg) {
     xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes);
     xla::ComputationDataHandle tuple =
         builder->Parameter(0, tuple_shape, "arg_tuple");
-    for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
-      (*context_args)[parameters[i]].value.handle =
-          builder->GetTupleElement(tuple, i);
+    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
+      arg_handles[i] = builder->GetTupleElement(tuple, i);
     }
   } else {
-    for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
-      (*context_args)[parameters[i]].value.handle =
+    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
+      arg_handles[i] =
           builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
     }
   }
 
+  // Fill in the handles in non-constant arguments.
+  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
+    const XlaCompiler::Argument& arg = args[parameters[i]];
+    XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
+    switch (arg.kind) {
+      case XlaCompiler::Argument::kResource:
+        TF_RET_CHECK(arg.initialized);
+        arg_expression.resource()->value = arg_handles[i];
+        arg_expression.resource()->initial_value = arg_handles[i];
+        break;
+      case XlaCompiler::Argument::kParameter:
+        arg_expression.set_handle(arg_handles[i]);
+        break;
+      case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kInvalid:
+        return errors::Internal("Unreachable case in BuildArguments()");
+    }
+  }
+
   return Status::OK();
 }
@@ -354,7 +373,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
 // index of a resource variable argument to the computation, and `type` is the
 // type of the final output.
 Status BuildComputation(
-    const std::vector<XlaContext::HandleOrConstant>& retvals,
+    const std::vector<XlaExpression>& retvals,
     const std::vector<std::unique_ptr<XlaResource>>& resources,
     bool has_side_effects, bool return_updated_values_for_all_resources,
     xla::ComputationBuilder* builder, xla::Computation* computation,
@@ -362,9 +381,9 @@ Status BuildComputation(
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
   std::vector<xla::ComputationDataHandle> elems;
   elems.reserve(retvals.size());
-  for (const XlaContext::HandleOrConstant& retval : retvals) {
-    if (!retval.is_constant) {
-      elems.push_back(retval.handle);
+  for (const XlaExpression& retval : retvals) {
+    if (!retval.has_constant_value()) {
+      elems.push_back(retval.handle());
     }
   }
   *num_nonconst_outputs = elems.size();
@@ -460,11 +479,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   result->tuple_arg = options.use_tuple_arg;
 
-  std::vector<XlaContext::Argument> context_args;
-  TF_RETURN_IF_ERROR(BuildArguments(args, options.use_tuple_arg, &builder,
-                                    &context_args, &result->input_mapping,
-                                    &result->xla_input_shapes));
-  context->set_args(std::move(context_args));
+  std::vector<XlaExpression> arg_expressions;
+  TF_RETURN_IF_ERROR(BuildArguments(
+      args, options.use_tuple_arg, &builder, context, &arg_expressions,
+      &result->input_mapping, &result->xla_input_shapes));
+  context->set_args(std::move(arg_expressions));
 
   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
                                   flib_runtime_, NextStepId()));
@@ -486,14 +505,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   VLOG(2) << "Outputs: total: " << context->retvals().size()
           << " nonconstant: " << num_nonconst_outputs;
   result->outputs.resize(context->retvals().size());
-  for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0;
+  for (std::vector<XlaExpression>::size_type i = 0;
        i < context->retvals().size(); ++i) {
-    const XlaContext::HandleOrConstant& retval = context->retvals()[i];
-    if (retval.is_constant) {
+    const XlaExpression& retval = context->retvals()[i];
+    if (retval.has_constant_value()) {
       OutputDescription& output = result->outputs[i];
-      output.shape = retval.constant_value.shape();
+      output.shape = retval.constant_value().shape();
       output.is_constant = true;
-      output.constant_value = retval.constant_value;
+      output.constant_value = retval.constant_value();
     }
   }
@@ -518,10 +537,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   // Converts the output shapes to TensorShapes.
   int computation_output = 0;
-  for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0;
+  for (std::vector<XlaExpression>::size_type i = 0;
        i < context->retvals().size(); ++i) {
-    const XlaContext::HandleOrConstant& retval = context->retvals()[i];
-    if (!retval.is_constant) {
+    const XlaExpression& retval = context->retvals()[i];
+    if (!retval.has_constant_value()) {
       CHECK_LT(computation_output, num_computation_outputs);
       OutputDescription& output = result->outputs[i];
       output.is_constant = false;


@@ -85,17 +85,9 @@ class XlaCompiler {
     // Argument is a compile-time constant. No associated runtime parameter.
     kConstant,
 
-    // Argument is a Variable resource. Has an associated runtime parameter
-    // iff `initialized` is true.
-    kVariable,
-
-    // Argument is a TensorArray resource. Has an associated runtime parameter
-    // iff `initialized` is true.
-    kTensorArray,
-
-    // Argument is a Stack resource. Has an associated runtime parameter
-    // iff `initialized` is true.
-    kStack,
+    // Argument is a Variable, TensorArray, or Stack resource. Has an
+    // associated runtime parameter iff `initialized` is true.
+    kResource,
 
     // Argument is a run-time parameter.
     kParameter,
@@ -118,11 +110,14 @@ class XlaCompiler {
     // The name of this argument, used for debugging.
     string name;
 
-    // For a kVariable or kTensorArray, has this resource been initialized?
+    // For a kResource, what kind of resource is it?
+    XlaResource::Kind resource_kind = XlaResource::kInvalid;
+
+    // For a kResource, has this resource been initialized?
     bool initialized = false;
 
-    // For a kTensorArray, what is the array's declared size? (Used for lazy
-    // initialization.)
+    // For a TensorArray or Stack resource, what is the array's declared size?
+    // (Used for lazy initialization.)
     int64 tensor_array_size = -1;
 
     bool operator==(const Argument& other) const;


@@ -58,7 +58,7 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
   return Get(ctx->op_kernel_context());
 }
 
-void XlaContext::set_args(std::vector<Argument> args) {
+void XlaContext::set_args(std::vector<XlaExpression> args) {
   args_ = std::move(args);
 }
@@ -78,8 +78,8 @@ XlaContext::GetOrCreateRuntimeContextParameter() {
   // Allocate the next available parameter for the context parameter.
   int num_parameters = 0;
-  for (const Argument& arg : args_) {
-    if (!arg.value.is_constant) {
+  for (const XlaExpression& arg : args_) {
+    if (!arg.has_constant_value()) {
       ++num_parameters;
     }
   }
@@ -99,9 +99,7 @@ void XlaContext::AddRetval(int retval_index, DataType type,
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  retvals_[retval_index].is_constant = false;
-  retvals_[retval_index].type = type;
-  retvals_[retval_index].handle = handle;
+  retvals_[retval_index].set_handle(handle);
 }
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@@ -111,14 +109,12 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  retvals_[retval_index].type = dtype;
   if (resolve_compile_time_constants_) {
-    retvals_[retval_index].is_constant = true;
-    TF_RETURN_IF_ERROR(LiteralToHostTensor(
-        literal, dtype, &retvals_[retval_index].constant_value));
+    Tensor value;
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
+    retvals_[retval_index].set_constant_value(std::move(value));
   } else {
-    retvals_[retval_index].is_constant = false;
-    retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
+    retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
   }
   return Status::OK();
 }


@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
@@ -37,34 +38,6 @@ class XlaOpKernelContext;
 // subgraph of Ops using XLA.
 class XlaContext : public ResourceBase {
  public:
-  // A struct that represents either a compile-time constant, or an XLA
-  // computation handle. Used to represent arguments and return values.
-  struct HandleOrConstant {
-    // Is this a compile-time constant? If so, what is its value?
-    bool is_constant;
-    Tensor constant_value;  // Must be in host memory.
-
-    // If this is not a constant, a computation handle. Since the mapping from
-    // Tensorflow types to XLA types is not necessarily injective (one-to-one),
-    // we also require the Tensorflow type.
-    DataType type;
-    xla::ComputationDataHandle handle;
-  };
-
-  struct Argument {
-    XlaCompiler::Argument::Kind kind;
-
-    // Descriptive name for the resource, for use in error messages.
-    string name;
-
-    // Is this a resource?
-    bool is_resource = false;
-
-    HandleOrConstant value;
-
-    int64 tensor_array_size = -1;
-  };
-
   // Retrieves the XlaContext of the current compilation.
   static XlaContext& Get(const OpKernelContext* ctx);
   static XlaContext& Get(const XlaOpKernelContext* ctx);
@@ -85,14 +58,14 @@ class XlaContext : public ResourceBase {
   bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
   bool has_context_parameter() const { return has_context_parameter_; }
 
-  const std::vector<Argument>& args() const { return args_; }
-  void set_args(std::vector<Argument> args);
+  const std::vector<XlaExpression>& args() const { return args_; }
+  void set_args(std::vector<XlaExpression> args);
 
   // Get the runtime context parameter, adding one if it does not already
   // exist. Dies if not compiling a local executable.
   const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
 
-  const std::vector<HandleOrConstant>& retvals() { return retvals_; }
+  const std::vector<XlaExpression>& retvals() { return retvals_; }
@@ -155,10 +128,10 @@ class XlaContext : public ResourceBase {
   // Arguments to the Tensorflow graph, indexed by _Arg index.
   // Includes both compile-time constant arguments and runtime parameters.
-  std::vector<Argument> args_;
+  std::vector<XlaExpression> args_;
 
   // Return values of the Tensorflow graph, indexed by _Retval index.
-  std::vector<HandleOrConstant> retvals_;
+  std::vector<XlaExpression> retvals_;
 
   // Does the computation have side effects, i.e., Send() calls?
   bool has_side_effects_ = false;
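
Finally, the third bullet of the commit message changes who constructs the argument expressions: BuildArguments in XlaCompiler now fills a std::vector<XlaExpression> and hands it to the context via set_args. A minimal sketch of that data flow, with XlaExpression and XlaContext reduced to illustrative placeholders:

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Placeholder types; the real XlaExpression/XlaContext live in tf2xla.
    struct XlaExpression { long handle = 0; };

    class XlaContext {
     public:
      // Mirrors the new XlaContext::set_args(std::vector<XlaExpression>).
      void set_args(std::vector<XlaExpression> args) { args_ = std::move(args); }
      const std::vector<XlaExpression>& args() const { return args_; }

     private:
      std::vector<XlaExpression> args_;
    };

    int main() {
      // Previously BuildArguments filled a parallel XlaContext::Argument
      // vector; now the compiler builds XlaExpressions directly and the
      // context merely stores them.
      std::vector<XlaExpression> arg_expressions(3);
      for (std::size_t i = 0; i < arg_expressions.size(); ++i) {
        arg_expressions[i].handle = static_cast<long>(i + 1);
      }
      XlaContext context;
      context.set_args(std::move(arg_expressions));
      return context.args().size() == 3 ? 0 : 1;
    }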