[TF:XLA] Refactor XlaContext, moving some of its responsibilities to XlaCompiler and XlaOpKernelContext.

Move handling of arguments and return values to XlaCompiler. Introduce a new XlaContext::HandleOrConstant structure and use it for both arguments and results.
Make XlaCompiler own the xla::ComputationBuilder.

Move code for wrapping/unwrapping XlaExpressions in Tensors to XlaOpKernelContext, which is its only consumer.

No functional changes.
Change: 147250375
author Peter Hawkins 2017-02-11 10:42:18 -08:00 (committed by TensorFlower Gardener)
parent 196ce75c99
commit 4e24bec418
7 changed files with 261 additions and 349 deletions
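
For orientation, here is a minimal, hypothetical C++ sketch of the argument/return-value representation this change introduces. The field names mirror XlaContext::HandleOrConstant as it appears in the diff below, but the TensorFlow and XLA types are replaced with stand-ins so the snippet is self-contained rather than a verbatim excerpt.

#include <cstdint>
#include <vector>

// Stand-in for tensorflow::Tensor (illustrative only).
struct HostTensor {};
// Stand-in for xla::ComputationDataHandle (illustrative only).
struct DataHandle { int64_t id = 0; };

// Mirrors XlaContext::HandleOrConstant: every argument and return value of
// the graph being compiled is either a compile-time constant or a handle
// into the XLA computation under construction.
struct HandleOrConstant {
  bool is_constant = false;
  HostTensor constant_value;  // meaningful only when is_constant is true
  DataHandle handle;          // meaningful only when is_constant is false
};

// After this change, XlaCompiler owns the xla::ComputationBuilder, builds the
// parameters itself, and hands the prepared values to the XlaContext:
//   std::vector<HandleOrConstant> args = ...;  // produced by BuildArguments
//   context->set_args(std::move(args));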


@@ -104,12 +104,12 @@ class ArgOp : public XlaOpKernel {
OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(),
errors::InvalidArgument("Invalid argument index ", index_));
const XlaCompiler::Argument& arg = tc.args()[index_];
if (arg.parameter < 0) {
const XlaContext::HandleOrConstant& arg = tc.args()[index_];
if (arg.is_constant) {
ctx->SetConstantOutput(0, arg.constant_value);
} else {
ctx->SetOutput(0, tc.parameter(arg.parameter));
ctx->SetOutput(0, arg.handle);
}
}


@@ -44,10 +44,9 @@ class XlaCompilationAllocator : public Allocator {
string Name() override { return "tla_jit"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
// Regardless of the size requested, always allocate a
// XlaExpression. Respect the alignment request because there is
// alignment checking even for Tensors whose data is never
// accessed.
// Regardless of the size requested, always allocates an XlaExpression.
// Respects the alignment request because there is alignment checking even
// for Tensors whose data is never accessed.
void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment);
XlaExpression* expression = reinterpret_cast<XlaExpression*>(p);
new (expression) XlaExpression();


@@ -271,6 +271,109 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
return cleanup_status;
}
// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation. If `use_tuple_arg` is true, a
// single tuple parameter will be used for all arguments; if false, each
// argument gets its own parameter.
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::ComputationBuilder* builder,
std::vector<XlaContext::HandleOrConstant>* context_args) {
context_args->resize(args.size());
// Computes the number of parameters, verifies that they are sequential
// starting from 0.
int num_parameters = 0;
for (int i = 0; i < args.size(); ++i) {
(*context_args)[i].is_constant = (args[i].parameter < 0);
(*context_args)[i].constant_value = args[i].constant_value;
if (args[i].parameter < 0) continue;
if (num_parameters != args[i].parameter) {
return errors::InvalidArgument(
"Parameter numbers to XLA compilation are not consecutive starting "
"from 0");
}
++num_parameters;
if (args[i].shape.num_elements() == 0) {
return errors::InvalidArgument(
"Non-constant argument must have a non-zero number of elements.");
}
}
if (num_parameters == 0) return Status::OK();
std::vector<xla::Shape> parameter_shapes(num_parameters);
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
// Computes the shapes of non-constant arguments.
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
&parameter_shapes[arg.parameter]);
}
if (use_tuple_arg && num_parameters > 0) {
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
xla::ComputationDataHandle tuple =
builder->Parameter(0, tuple_shape, "arg_tuple");
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
(*context_args)[i].handle =
builder->GetTupleElement(tuple, arg.parameter);
}
} else {
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
(*context_args)[i].handle =
builder->Parameter(arg.parameter, parameter_shapes[arg.parameter],
strings::StrCat("arg", i));
}
}
return Status::OK();
}
// Builds the XLA computation. `retvals` is the list of retvals produced by
// _Retval operators, in index order. `has_side_effects` should be true if the
// computation has side effects and should be built even if it has no outputs.
// `num_nonconst_outputs` is set to the number of outputs of the `computation`.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
bool has_side_effects, xla::ComputationBuilder* builder,
xla::Computation* computation, int* num_nonconst_outputs) {
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retvals.size());
for (const XlaContext::HandleOrConstant& retval : retvals) {
if (!retval.is_constant) {
elems.push_back(retval.handle);
}
}
if (!elems.empty() || has_side_effects) {
// Builds an empty tuple return value for computations that have side effects
// but have no return values.
xla::ComputationDataHandle handle = builder->Tuple(elems);
// TODO(b/31775371): to workaround bug, we must build a no-op computation
// that is guaranteed to be constructed after all of the formal parameters
// to the computation. Once the bug is fixed, we could avoid tupling here.
if (elems.size() == 1) {
handle = builder->GetTupleElement(handle, 0);
}
// Builds the XLA computation.
xla::StatusOr<xla::Computation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
}
*num_nonconst_outputs = elems.size();
return Status::OK();
}
} // namespace
Status XlaCompiler::CompileGraph(string const& name,
@@ -292,41 +395,41 @@ Status XlaCompiler::CompileGraph(string const& name,
args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
}
XlaContext* xla_context =
new XlaContext(this, client(), name, allow_cpu_custom_calls_,
resolve_compile_time_constants_);
core::ScopedUnref xla_context_unref(xla_context);
xla::ComputationBuilder builder(client(), name);
TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
XlaContext* context = new XlaContext(this, &builder, allow_cpu_custom_calls_,
resolve_compile_time_constants_);
core::ScopedUnref context_unref(context);
std::vector<XlaContext::HandleOrConstant> context_args;
TF_RETURN_IF_ERROR(
BuildArguments(args, use_tuple_arg, &builder, &context_args));
context->set_args(std::move(context_args));
TF_RETURN_IF_ERROR(
ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));
ExecuteGraph(context, std::move(graph), device_, flib, NextStepId()));
std::vector<XlaContext::ConstRetVal> compile_time_constants;
int num_nonconst_outputs;
TF_RETURN_IF_ERROR(xla_context->CollectResults(
&result->computation, &result->requires_runtime_context,
&compile_time_constants, &num_nonconst_outputs));
TF_RETURN_IF_ERROR(
BuildComputation(context->retvals(), context->has_side_effects(),
&builder, &result->computation, &num_nonconst_outputs));
VLOG(2) << "Outputs: constant: " << compile_time_constants.size()
result->requires_runtime_context = context->has_context_parameter();
// Tuple arguments and runtime context parameters are incompatible.
CHECK(!(use_tuple_arg && result->requires_runtime_context));
VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs;
result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
for (const auto& c : compile_time_constants) {
if (!c.status.ok()) {
Status constant_status = c.status;
errors::AppendToMessage(&constant_status,
"Failed evaluating constant XLA return "
"value ",
c.index);
return constant_status;
result->outputs.resize(context->retvals().size());
for (int i = 0; i < context->retvals().size(); ++i) {
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
if (retval.is_constant) {
OutputDescription& output = result->outputs[i];
output.shape = retval.constant_value.shape();
output.is_constant = true;
output.constant_value = retval.constant_value;
}
if (c.index >= result->outputs.size()) {
return errors::InvalidArgument("Invalid argument index ", c.index);
}
OutputDescription& output = result->outputs[c.index];
output.shape = c.value.shape();
output.is_constant = true;
output.constant_value = c.value;
}
if (result->computation.IsNull()) {
@@ -363,16 +466,18 @@ Status XlaCompiler::CompileGraph(string const& name,
// Converts the output shapes to TensorShapes.
int computation_output = 0;
for (int i = 0; i < result->outputs.size(); ++i) {
if (!result->outputs[i].is_constant) {
for (int i = 0; i < context->retvals().size(); ++i) {
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
if (!retval.is_constant) {
CHECK_LT(computation_output, num_non_constant_outputs);
OutputDescription& output = result->outputs[i];
output.is_constant = false;
if (num_non_constant_outputs > 1) {
result->outputs[i].shape =
output.shape =
XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
result->xla_output_shape, computation_output));
} else {
result->outputs[i].shape =
XLAShapeToTensorShape(result->xla_output_shape);
output.shape = XLAShapeToTensorShape(result->xla_output_shape);
}
++computation_output;
}


@@ -31,8 +31,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
@@ -68,142 +66,33 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
return *context;
}
Status XlaContext::BuildArguments(std::vector<XlaCompiler::Argument> args,
bool use_tuple_arg) {
void XlaContext::set_args(std::vector<HandleOrConstant> args) {
args_ = std::move(args);
use_tuple_arg_ = use_tuple_arg;
// Compute the number of parameters, verify that they are sequential starting
// from 0
num_parameters_ = 0;
for (const XlaCompiler::Argument& arg : args_) {
if (arg.parameter < 0) continue;
if (num_parameters_ != arg.parameter) {
return errors::InvalidArgument(
"Parameter numbers to JIT compilation are not consecutive starting "
"from 0");
}
++num_parameters_;
if (arg.shape.num_elements() == 0) {
return errors::InvalidArgument(
"Non-constant argument must have a non-zero number of elements.");
}
}
if (num_parameters_ == 0) return Status::OK();
parameters_.resize(num_parameters_);
std::vector<xla::Shape> parameter_shapes(num_parameters_);
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
// Computes the shapes of non-constant arguments.
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
&parameter_shapes[arg.parameter]);
}
if (use_tuple_arg_ && num_parameters_ > 0) {
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
xla::ComputationDataHandle tuple =
builder().Parameter(0, tuple_shape, "arg_tuple");
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
parameters_[arg.parameter] =
builder().GetTupleElement(tuple, arg.parameter);
}
} else {
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
parameters_[arg.parameter] =
builder().Parameter(arg.parameter, parameter_shapes[arg.parameter],
strings::StrCat("arg", i));
}
}
return Status::OK();
}
Status XlaContext::CollectResults(
xla::Computation* computation, bool* requires_runtime_context,
std::vector<ConstRetVal>* compile_time_constants,
int* num_nonconst_outputs) {
xla::ComputationDataHandle handle;
if (retval_.empty() && has_side_effects_) {
// Build an empty tuple return value for computations that have side effects
// but have no return values.
handle = builder().Tuple({});
} else if (retval_.size() == 1) {
handle = retval_[0].second;
// TODO(b/31775371): to workaround bug, add a no-op computation that is
// guaranteed to be constructed after all of the formal parameters to the
// computation.
handle = builder().GetTupleElement(builder().Tuple({handle}), 0);
// Ensure that the retval is returned even if another computation
// was mistakenly placed on the ComputationBuilder.
TF_CHECK_OK(builder().SetReturnValue(handle));
} else if (retval_.size() > 1) {
// There is at least one data-dependent expression: combine them
// into a Tuple in index order before compiling.
VLOG(1) << "Making the retval tuple.";
std::sort(retval_.begin(), retval_.end(),
[](const std::pair<int, xla::ComputationDataHandle>& a,
const std::pair<int, xla::ComputationDataHandle>& b) {
return a.first < b.first;
});
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retval_.size());
for (const std::pair<int, xla::ComputationDataHandle>& r : retval_) {
elems.push_back(r.second);
}
// Make a tuple from the vector of handles.
handle = builder().Tuple(elems);
}
if (handle.handle() > 0) {
// Builds the XLA computation.
xla::StatusOr<xla::Computation> computation_status = builder().Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
}
// Make sure the compile time constants are in RetVal index order.
std::sort(compile_time_constant_.begin(), compile_time_constant_.end(),
[](const ConstRetVal& a, const ConstRetVal& b) {
return a.index < b.index;
});
// Fill in the result details and return.
*compile_time_constants = std::move(compile_time_constant_);
*requires_runtime_context = has_context_parameter_;
*num_nonconst_outputs = retval_.size();
return Status::OK();
}
XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
const string& computation_name,
XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
bool allow_cpu_custom_calls,
bool resolve_compile_time_constants)
: compiler_(compiler),
xla_builder_(client, computation_name),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants) {}
const xla::ComputationDataHandle&
XlaContext::GetOrCreateRuntimeContextParameter() {
CHECK(allow_cpu_custom_calls_);
CHECK(!use_tuple_arg_);
if (has_context_parameter_) return context_parameter_;
has_context_parameter_ = true;
context_parameter_ = xla_builder_.Parameter(
num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
// Allocate the next available parameter for the context parameter.
int num_parameters = 0;
for (const HandleOrConstant& arg : args_) {
if (!arg.is_constant) {
++num_parameters;
}
}
context_parameter_ = builder_->Parameter(
num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
return context_parameter_;
}
@@ -214,23 +103,28 @@ string XlaContext::DebugString() { return "TLA JIT context"; }
void XlaContext::AddRetval(int retval_index,
const xla::ComputationDataHandle& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up. The executor
// is multi-threaded so this has to happen under the
// lock.
retval_.emplace_back(retval_index, handle);
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
retvals_[retval_index].is_constant = false;
retvals_[retval_index].handle = handle;
}
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal) {
VLOG(1) << "Adding retval index " << retval_index
<< " with non-data-dependent tensor to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
if (resolve_compile_time_constants_) {
ConstRetVal value;
value.index = retval_index;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
compile_time_constant_.push_back(std::move(value));
retvals_[retval_index].is_constant = true;
TF_RETURN_IF_ERROR(LiteralToHostTensor(
literal, dtype, &retvals_[retval_index].constant_value));
} else {
retval_.emplace_back(retval_index, xla_builder_.ConstantLiteral(literal));
retvals_[retval_index].is_constant = false;
retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
}
return Status::OK();
}
@@ -239,40 +133,13 @@ void XlaContext::AddSideEffects() {
has_side_effects_ = true;
}
/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK_NE(expression->handle().handle(), 0);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor(
Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK_EQ(expression->handle().handle(), 0);
return const_cast<XlaExpression*>(expression);
}
/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor(
const Tensor& tensor) {
return CastExpressionFromTensor(tensor);
}
/* static */ const xla::ComputationDataHandle&
XlaContext::GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
}
xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; }
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
return LookupOrCreate(type, &max_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Max() for " << type_string;
xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">");
xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -286,7 +153,7 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
return LookupOrCreate(type, &add_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Add() for " << type_string;
xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">");
xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -300,7 +167,7 @@ const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
return LookupOrCreate(type, &sigmoid_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Sigmoid() for " << type_string;
xla::ComputationBuilder b(builder().client(),
xla::ComputationBuilder b(builder()->client(),
"sigmoid<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the contexts used to represent XLA JIT computations.
// This file defines the contexts used during XLA compilation.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
@@ -33,11 +33,11 @@ limitations under the License.
namespace tensorflow {
// A XlaExpression wraps an XLA computation. Each Tensor sent
// along an edge during XLA JIT compilation represents a
// along an edge during XLA compilation represents a
// XlaExpression, and the shape of the Tensor matches the shape of
// the subcomputation in the ComputationDataHandle. Each
// expression is either a constant, an unbound parameter, or a
// function of previously-compiled expressions.
// expression is either a constant, or a function of previously-compiled
// expressions.
class XlaExpression {
public:
XlaExpression();
@@ -52,8 +52,6 @@ class XlaExpression {
const Tensor& constant_value() const { return constant_value_; }
private:
friend class XlaContext;
// The XLA handle of the expression's computation.
xla::ComputationDataHandle handle_;
@@ -66,90 +64,52 @@
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};
// The XlaContext is the data structure accessible from
// OpKernelContexts when evaluating a subgraph of Ops for JIT
// compilation by XLA. When an Op is executed during JIT
// compilation the input Tensors to the Op store handles to
// subcomputations compiled by earlier Ops in the subgraph. The Op can
// retrieve these subcomputations by calling either
// GetExpressionFromTensor, which returns the XlaExpression holding
// the subcomputation; or EvaluateAsConstant which returns an XLA
// literal of the result of the subcomputation or an error status if
// the subcomputation depends on unbound parameters. The Op may then
// use the ComputationBuilder available from XlaContext::builder()
// to compile one or more functions of the inputs into
// ComputationDataHandles. The handles can be stored as new
// expressions corresponding to the outputs of the Op by calling
// CreateOutputTensorFromComputation or
// CreateConstantOutputTensor. The *only* correct way to allocate an
// output tensor is using one of the preceding two methods, since they
// ensure there is a valid XlaExpression backing the output
// tensor. No Op should ever call allocate_output or allocate_temp
// directly on the OpKernelContext. It is permissible to pass a tensor
// from an Op input to an output (e.g. call ctx->set_output with a
// tensor passed as an input). As an example, the softmax Op produces
// output from input as follows:
//
// XlaContext& tc = XlaContext::Get(context);
// xla::ComputationBuilder& b = tc.builder();
// xla::ComputationDataHandle logits =
// tc.GetComputationFromTensor(logits_in));
// ... The softmax computation uses the builder b to compute a
// xla::ComputationDataHandle softmax holding the desired output.
// ...
// OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation(
// context, 0, logits_in.shape().dim_sizes(),
// softmax));
//
// The XlaContext is the data structure that holds the state of an XLA
// compilation, that is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
class XlaContext : public ResourceBase {
public:
// If a retval can be evaluated at JIT time it is returned as a
// Literal in a ConstRetVal struct as part of the ComputationResult.
// TODO(misard) reconcile this with the duplicate data structure in
// the XlaCompilationCache class.
struct ConstRetVal {
// The index of the RetVal corresponding to this constant literal.
int index;
// If status is not OK, value's data is undefined.
Status status;
// The value of the RetVal evaluated at JIT compilation
// time. value.shape() always gives the correct shape of the
// RetVal. If !status.ok() then value's data is undefined, otherwise the
// Tensor buffer is allocated in CPU memory.
Tensor value;
// A struct that represents either a compile-time constant, or an XLA
// computation handle. Used to represent arguments and return values.
struct HandleOrConstant {
// Is this a compile-time constant? If so, what is its value?
bool is_constant;
Tensor constant_value; // Must be in host memory.
// If this is not a constant, a computation handle.
xla::ComputationDataHandle handle;
};
// Virtual method defined by ResourceBase.
string DebugString() override;
// Retrieve the XlaContext corresponding to a step's JIT compilation.
// Retrieves the XlaContext of the current compilation.
static XlaContext& Get(const OpKernelContext* ctx);
static XlaContext& Get(const XlaOpKernelContext* ctx) {
return Get(ctx->op_kernel_context());
}
// Create a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::Client* client,
const string& computation_name, bool allow_cpu_custom_calls,
bool resolve_compile_time_constants);
// Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
// Builds XLA computations for each of the arguments.
// Should only be called once to initialize the arguments. Not thread-safe.
Status BuildArguments(std::vector<XlaCompiler::Argument> arguments,
bool use_tuple_arg) TF_MUST_USE_RESULT;
// Virtual method defined by ResourceBase.
string DebugString() override;
// Returns the results of the symbolic computation that have accumulated in
// the XlaContext. After CollectResults() is called, the context is left in
// an invalid state and must not be reused.
// Sets `requires_runtime_context` if the emitted computation requires a
// runtime context argument. `compile_time_constants` describes any non
// data-dependent results of the computation. `num_nonconst_ouputs` is set to
// the number of outputs of the `computation`.
Status CollectResults(xla::Computation* computation,
bool* requires_runtime_context,
std::vector<ConstRetVal>* compile_time_constants,
int* num_nonconst_outputs);
XlaCompiler* compiler() const { return compiler_; }
// Returns the ComputationBuilder that Ops use for compiling new
// expressions.
xla::ComputationBuilder* builder();
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool has_context_parameter() const { return has_context_parameter_; }
const std::vector<HandleOrConstant>& args() const { return args_; }
void set_args(std::vector<HandleOrConstant> args);
// Get the runtime context parameter, adding one if it does not already exist.
// Dies if not compiling a local executable.
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
const std::vector<HandleOrConstant>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
@@ -159,30 +119,10 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal);
// Mark the computation as having side effects (i.e., Send operators).
// Mark the computation as having side effects (e.g., Send operators).
void AddSideEffects();
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
// computation was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const xla::ComputationDataHandle& GetComputationFromTensor(
const Tensor& tensor);
XlaCompiler* compiler() const { return compiler_; }
// Returns the ComputationBuilder that Ops use for compiling new
// expressions.
xla::ComputationBuilder& builder();
const std::vector<XlaCompiler::Argument>& args() const { return args_; }
xla::ComputationDataHandle parameter(int num) { return parameters_[num]; }
// Get the runtime context parameter, adding one if it does not already exist.
// Dies if not compiling a local executable.
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool has_side_effects() const { return has_side_effects_; }
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -203,39 +143,11 @@
static const char kXlaContextResourceName[];
private:
friend class XlaOpKernelContext;
// This method is used to retrieve an expression that was allocated by
// a previous Op.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
// This method is used to retrieve an uninitialized expression from a
// newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor);
// Retrieves the expression from an input Tensor to an Op. This
// expression was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
XlaCompiler* const compiler_;
// The ComputationBuilder used to construct the subgraph's compiled
// representation.
xla::ComputationBuilder xla_builder_;
// Number of XLA Parameters, not counting the context parameter, if any.
int num_parameters_;
// Arguments to the JIT compilation, both compile-time constant arguments and
// runtime parameters.
std::vector<XlaCompiler::Argument> args_;
bool use_tuple_arg_ = false;
// Runtime parameters to the XLA computation. Does not include
// compile-time constant arguments.
std::vector<xla::ComputationDataHandle> parameters_;
xla::ComputationBuilder* builder_;
// Allow ops to emit CustomCall operations for CPU.
const bool allow_cpu_custom_calls_;
@@ -251,11 +163,12 @@
bool has_context_parameter_ = false;
xla::ComputationDataHandle context_parameter_;
// The data-dependent return values of the computation.
std::vector<std::pair<int, xla::ComputationDataHandle>> retval_;
// Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters.
std::vector<HandleOrConstant> args_;
// The non-data-dependent return values of the computation.
std::vector<ConstRetVal> compile_time_constant_;
// Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<HandleOrConstant> retvals_;
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;


@@ -31,11 +31,37 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
}
xla::ComputationBuilder* XlaOpKernelContext::builder() const {
return &XlaContext::Get(this).builder();
return XlaContext::Get(this).builder();
}
// Retrieves an XlaExpression that was allocated by a previous Op.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK_NE(expression->handle().handle(), 0);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK_EQ(expression->handle().handle(), 0);
return const_cast<XlaExpression*>(expression);
}
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
// computation was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const xla::ComputationDataHandle& GetComputationFromTensor(
const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
}
const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
return XlaContext::GetComputationFromTensor(context_->input(index));
return GetComputationFromTensor(context_->input(index));
}
TensorShape XlaOpKernelContext::InputShape(int index) {
@@ -60,8 +86,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
" but was asked to be reshaped to incompatible shape ",
new_shape.DebugString());
}
const XlaExpression* expression =
XlaContext::CastExpressionFromTensor(tensor);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
@@ -159,7 +184,7 @@ Status XlaOpKernelContext::InputList(
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
handles->push_back(XlaContext::GetComputationFromTensor(input));
handles->push_back(GetComputationFromTensor(input));
shapes->push_back(input.shape());
}
return Status::OK();
@@ -196,8 +221,7 @@ void XlaOpKernelContext::SetOutput(int index,
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression =
XlaContext::CastExpressionFromUninitializedTensor(output);
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
}
@@ -217,8 +241,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression =
XlaContext::CastExpressionFromUninitializedTensor(output);
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
expression->set_constant_value(constant);
}


@@ -45,9 +45,14 @@ class XlaOpKernel : public OpKernel {
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and consumes
// data as XLA computations, rather than as standard Tensors. (Under the hood,
// symbolic execution communicates using special Tensors, but that is an
// implementation detail that this class hides.)
// data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects, however this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
public:
explicit XlaOpKernelContext(OpKernelContext* context);
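
As a usage illustration of the contract described in the xla_op_kernel.h comment above, here is a short hypothetical operator written against XlaOpKernelContext. It is not part of this change; it assumes the tf2xla headers from this repository and the Input/SetOutput methods shown in the diff.

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

namespace tensorflow {
namespace {

// Hypothetical identity-style op: forwards its symbolic input to its output.
// Outputs are produced only through XlaOpKernelContext (SetOutput or
// SetConstantOutput), never by calling allocate_output/allocate_temp on the
// underlying OpKernelContext.
class IdentitySketchOp : public XlaOpKernel {
 public:
  explicit IdentitySketchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Input(0) yields the xla::ComputationDataHandle wrapped by the input
    // Tensor's XlaExpression -- an implementation detail this class hides.
    ctx->SetOutput(0, ctx->Input(0));
  }
};

}  // namespace
}  // namespace tensorflow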