[TF:XLA] Refactor XlaContext, moving some of its responsibilities to XlaCompiler and XlaOpKernelContext.

Move handling of arguments and return values to XlaCompiler. Introduce a new XlaContext::HandleOrConstant structure, use it for both arguments and results.
Make XlaCompiler own the xla::ComputationBuilder.

Move code for wrapping/unwrapping XlaExpressions in Tensors to XlaOpKernelContext, which is its only consumer.

No functional changes.
Change: 147250375
Peter Hawkins 2017-02-11 10:42:18 -08:00 committed by TensorFlower Gardener
parent 196ce75c99
commit 4e24bec418
7 changed files with 261 additions and 349 deletions
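As a quick orientation before the per-file diffs, here is a minimal sketch of the new representation this change introduces. The struct fields mirror the XlaContext::HandleOrConstant declaration in the xla_context.h hunk below; the helper function NonConstantHandles is hypothetical and only illustrates how a consumer distinguishes the two cases, in the same way BuildComputation() does in xla_compiler.cc.

// Sketch: mirrors XlaContext::HandleOrConstant from the xla_context.h diff.
// Every argument and return value is either a host-memory constant Tensor or
// a handle into the xla::ComputationBuilder that XlaCompiler now owns.
struct HandleOrConstant {
  bool is_constant;                   // True => constant_value holds the data.
  Tensor constant_value;              // Compile-time constant, host memory.
  xla::ComputationDataHandle handle;  // Valid only when !is_constant.
};

// Hypothetical consumer, shaped like BuildComputation() below: keep only the
// non-constant retvals, which become the outputs of the emitted computation.
std::vector<xla::ComputationDataHandle> NonConstantHandles(
    const std::vector<HandleOrConstant>& retvals) {
  std::vector<xla::ComputationDataHandle> elems;
  for (const HandleOrConstant& retval : retvals) {
    if (!retval.is_constant) {
      elems.push_back(retval.handle);
    }
  }
  return elems;
}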


@@ -104,12 +104,12 @@ class ArgOp : public XlaOpKernel {
    OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(),
                errors::InvalidArgument("Invalid argument index ", index_));

-    const XlaCompiler::Argument& arg = tc.args()[index_];
-    if (arg.parameter < 0) {
+    const XlaContext::HandleOrConstant& arg = tc.args()[index_];
+    if (arg.is_constant) {
      ctx->SetConstantOutput(0, arg.constant_value);
    } else {
-      ctx->SetOutput(0, tc.parameter(arg.parameter));
+      ctx->SetOutput(0, arg.handle);
    }
  }
}


@@ -44,10 +44,9 @@ class XlaCompilationAllocator : public Allocator {
  string Name() override { return "tla_jit"; }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
-    // Regardless of the size requested, always allocate a
-    // XlaExpression. Respect the aligment request because there is
-    // alignment checking even for Tensors whose data is never
-    // accessed.
+    // Regardless of the size requested, always allocates an XlaExpression.
+    // Respects the aligment request because there is alignment checking even
+    // for Tensors whose data is never accessed.
    void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment);
    XlaExpression* expression = reinterpret_cast<XlaExpression*>(p);
    new (expression) XlaExpression();


@@ -271,6 +271,109 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
  return cleanup_status;
}
// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation. If `use_tuple_arg` is true, a
// single tuple parameter will be used for all arguments; if false, each
// argument gets its own parameter.
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::ComputationBuilder* builder,
std::vector<XlaContext::HandleOrConstant>* context_args) {
context_args->resize(args.size());
// Computes the number of parameters, verifies that they are sequential
// starting from 0.
int num_parameters = 0;
for (int i = 0; i < args.size(); ++i) {
(*context_args)[i].is_constant = (args[i].parameter < 0);
(*context_args)[i].constant_value = args[i].constant_value;
if (args[i].parameter < 0) continue;
if (num_parameters != args[i].parameter) {
return errors::InvalidArgument(
"Parameter numbers to XLA compilation are not consecutive starting "
"from 0");
}
++num_parameters;
if (args[i].shape.num_elements() == 0) {
return errors::InvalidArgument(
"Non-constant argument must have a non-zero number of elements.");
}
}
if (num_parameters == 0) return Status::OK();
std::vector<xla::Shape> parameter_shapes(num_parameters);
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
// Computes the shapes of non-constant arguments.
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
&parameter_shapes[arg.parameter]);
}
if (use_tuple_arg && num_parameters > 0) {
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
xla::ComputationDataHandle tuple =
builder->Parameter(0, tuple_shape, "arg_tuple");
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
(*context_args)[i].handle =
builder->GetTupleElement(tuple, arg.parameter);
}
} else {
for (int i = 0; i < args.size(); ++i) {
const XlaCompiler::Argument& arg = args[i];
if (arg.parameter < 0) continue;
(*context_args)[i].handle =
builder->Parameter(arg.parameter, parameter_shapes[arg.parameter],
strings::StrCat("arg", i));
}
}
return Status::OK();
}
// Builds the XLA computation. `retvals` is the list of retvals produced by
// _Retval operators, in index order. `has_side_effects` should be true if the
// computation has side effects and should be built even if it has no outputs.
// `num_nonconst_outputs` is set to the number of outputs of the `computation`.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
bool has_side_effects, xla::ComputationBuilder* builder,
xla::Computation* computation, int* num_nonconst_outputs) {
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retvals.size());
for (const XlaContext::HandleOrConstant& retval : retvals) {
if (!retval.is_constant) {
elems.push_back(retval.handle);
}
}
if (!elems.empty() || has_side_effects) {
// Builds a empty tuple return value for computations that have side effects
// but have no return values.
xla::ComputationDataHandle handle = builder->Tuple(elems);
// TODO(b/31775371): to workaround bug, we must build a no-op computation
// that is guaranteed to be constructed after all of the formal parameters
// to the computation. Once the bug is fixed, we could avoid tupling here.
if (elems.size() == 1) {
handle = builder->GetTupleElement(handle, 0);
}
// Builds the XLA computation.
xla::StatusOr<xla::Computation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
}
*num_nonconst_outputs = elems.size();
return Status::OK();
}
} // namespace

Status XlaCompiler::CompileGraph(string const& name,

@@ -292,41 +395,41 @@ Status XlaCompiler::CompileGraph(string const& name,
        args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
  }

-  XlaContext* xla_context =
-      new XlaContext(this, client(), name, allow_cpu_custom_calls_,
-                     resolve_compile_time_constants_);
-  core::ScopedUnref xla_context_unref(xla_context);
-
-  TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
+  xla::ComputationBuilder builder(client(), name);
+
+  XlaContext* context = new XlaContext(this, &builder, allow_cpu_custom_calls_,
+                                       resolve_compile_time_constants_);
+  core::ScopedUnref context_unref(context);
+
+  std::vector<XlaContext::HandleOrConstant> context_args;
+  TF_RETURN_IF_ERROR(
+      BuildArguments(args, use_tuple_arg, &builder, &context_args));
+  context->set_args(std::move(context_args));

  TF_RETURN_IF_ERROR(
-      ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));
+      ExecuteGraph(context, std::move(graph), device_, flib, NextStepId()));

-  std::vector<XlaContext::ConstRetVal> compile_time_constants;
  int num_nonconst_outputs;
-  TF_RETURN_IF_ERROR(xla_context->CollectResults(
-      &result->computation, &result->requires_runtime_context,
-      &compile_time_constants, &num_nonconst_outputs));
-
-  VLOG(2) << "Outputs: constant: " << compile_time_constants.size()
+  TF_RETURN_IF_ERROR(
+      BuildComputation(context->retvals(), context->has_side_effects(),
+                       &builder, &result->computation, &num_nonconst_outputs));
+
+  result->requires_runtime_context = context->has_context_parameter();
+
+  // Tuple arguments and runtime context parameters are incompatible.
+  CHECK(!(use_tuple_arg && result->requires_runtime_context));
+
+  VLOG(2) << "Outputs: total: " << context->retvals().size()
          << " nonconstant: " << num_nonconst_outputs;
-  result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
-  for (const auto& c : compile_time_constants) {
-    if (!c.status.ok()) {
-      Status constant_status = c.status;
-      errors::AppendToMessage(&constant_status,
-                              "Failed evaluating constant XLA return "
-                              "value ",
-                              c.index);
-      return constant_status;
-    }
-    if (c.index >= result->outputs.size()) {
-      return errors::InvalidArgument("Invalid argument index ", c.index);
-    }
-    OutputDescription& output = result->outputs[c.index];
-    output.shape = c.value.shape();
-    output.is_constant = true;
-    output.constant_value = c.value;
+  result->outputs.resize(context->retvals().size());
+  for (int i = 0; i < context->retvals().size(); ++i) {
+    const XlaContext::HandleOrConstant& retval = context->retvals()[i];
+    if (retval.is_constant) {
+      OutputDescription& output = result->outputs[i];
+      output.shape = retval.constant_value.shape();
+      output.is_constant = true;
+      output.constant_value = retval.constant_value;
+    }
  }

  if (result->computation.IsNull()) {

@@ -363,16 +466,18 @@ Status XlaCompiler::CompileGraph(string const& name,
  // Converts the output shapes to TensorShapes.
  int computation_output = 0;
-  for (int i = 0; i < result->outputs.size(); ++i) {
-    if (!result->outputs[i].is_constant) {
+  for (int i = 0; i < context->retvals().size(); ++i) {
+    const XlaContext::HandleOrConstant& retval = context->retvals()[i];
+    if (!retval.is_constant) {
      CHECK_LT(computation_output, num_non_constant_outputs);
+      OutputDescription& output = result->outputs[i];
+      output.is_constant = false;
      if (num_non_constant_outputs > 1) {
-        result->outputs[i].shape =
+        output.shape =
            XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
                result->xla_output_shape, computation_output));
      } else {
-        result->outputs[i].shape =
-            XLAShapeToTensorShape(result->xla_output_shape);
+        output.shape = XLAShapeToTensorShape(result->xla_output_shape);
      }
      ++computation_output;
    }


@@ -31,8 +31,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
@@ -68,142 +66,33 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
  return *context;
}

-Status XlaContext::BuildArguments(std::vector<XlaCompiler::Argument> args,
-                                  bool use_tuple_arg) {
+void XlaContext::set_args(std::vector<HandleOrConstant> args) {
  args_ = std::move(args);
use_tuple_arg_ = use_tuple_arg;
// Compute the number of parameters, verify that they are sequential starting
// from 0
num_parameters_ = 0;
for (const XlaCompiler::Argument& arg : args_) {
if (arg.parameter < 0) continue;
if (num_parameters_ != arg.parameter) {
return errors::InvalidArgument(
"Parameter numbers to JIT compilation are not consecutive starting "
"from 0");
}
++num_parameters_;
if (arg.shape.num_elements() == 0) {
return errors::InvalidArgument(
"Non-constant argument must have a non-zero number of elements.");
}
}
if (num_parameters_ == 0) return Status::OK();
parameters_.resize(num_parameters_);
std::vector<xla::Shape> parameter_shapes(num_parameters_);
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
// Computes the shapes of non-constant arguments.
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
&parameter_shapes[arg.parameter]);
}
if (use_tuple_arg_ && num_parameters_ > 0) {
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
xla::ComputationDataHandle tuple =
builder().Parameter(0, tuple_shape, "arg_tuple");
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
parameters_[arg.parameter] =
builder().GetTupleElement(tuple, arg.parameter);
}
} else {
for (int i = 0; i < args_.size(); ++i) {
const XlaCompiler::Argument& arg = args_[i];
if (arg.parameter < 0) continue;
parameters_[arg.parameter] =
builder().Parameter(arg.parameter, parameter_shapes[arg.parameter],
strings::StrCat("arg", i));
}
}
return Status::OK();
}

-Status XlaContext::CollectResults(
+XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
xla::Computation* computation, bool* requires_runtime_context,
std::vector<ConstRetVal>* compile_time_constants,
int* num_nonconst_outputs) {
xla::ComputationDataHandle handle;
if (retval_.empty() && has_side_effects_) {
// Build a empty tuple return value for computations that have side effects
// but have no return values.
handle = builder().Tuple({});
} else if (retval_.size() == 1) {
handle = retval_[0].second;
// TODO(b/31775371): to workaround bug, add a no-op computation that is
// guaranteed to be constructed after all of the formal parameters to the
// computation.
handle = builder().GetTupleElement(builder().Tuple({handle}), 0);
// Ensure that the retval is returned even if another computation
// was mistakenly placed on the ComputationBuilder.
TF_CHECK_OK(builder().SetReturnValue(handle));
} else if (retval_.size() > 1) {
// There is at least one data-dependent expression: combine them
// into a Tuple in index order before compiling.
VLOG(1) << "Making the retval tuple.";
std::sort(retval_.begin(), retval_.end(),
[](const std::pair<int, xla::ComputationDataHandle>& a,
const std::pair<int, xla::ComputationDataHandle>& b) {
return a.first < b.first;
});
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retval_.size());
for (const std::pair<int, xla::ComputationDataHandle>& r : retval_) {
elems.push_back(r.second);
}
// Make a tuple from the vector of handles.
handle = builder().Tuple(elems);
}
if (handle.handle() > 0) {
// Builds the XLA computation.
xla::StatusOr<xla::Computation> computation_status = builder().Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
}
// Make sure the compile time constants are in RetVal index order.
std::sort(compile_time_constant_.begin(), compile_time_constant_.end(),
[](const ConstRetVal& a, const ConstRetVal& b) {
return a.index < b.index;
});
// Fill in the result details and return.
*compile_time_constants = std::move(compile_time_constant_);
*requires_runtime_context = has_context_parameter_;
*num_nonconst_outputs = retval_.size();
return Status::OK();
}
-XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
-                       const string& computation_name,
                       bool allow_cpu_custom_calls,
                       bool resolve_compile_time_constants)
    : compiler_(compiler),
-      xla_builder_(client, computation_name),
+      builder_(builder),
      allow_cpu_custom_calls_(allow_cpu_custom_calls),
      resolve_compile_time_constants_(resolve_compile_time_constants) {}
const xla::ComputationDataHandle&
XlaContext::GetOrCreateRuntimeContextParameter() {
  CHECK(allow_cpu_custom_calls_);
-  CHECK(!use_tuple_arg_);
  if (has_context_parameter_) return context_parameter_;
  has_context_parameter_ = true;
-  context_parameter_ = xla_builder_.Parameter(
-      num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
+
+  // Allocate the next available parameter for the context parameter.
+  int num_parameters = 0;
+  for (const HandleOrConstant& arg : args_) {
+    if (!arg.is_constant) {
+      ++num_parameters;
+    }
+  }
+  context_parameter_ = builder_->Parameter(
+      num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
  return context_parameter_;
}
@@ -214,23 +103,28 @@ string XlaContext::DebugString() { return "TLA JIT context"; }

void XlaContext::AddRetval(int retval_index,
                           const xla::ComputationDataHandle& handle) {
  VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
-  // Add the return value to the list being built up. The executor
-  // is multi-threaded so this has to happen under the
-  // lock.
-  retval_.emplace_back(retval_index, handle);
+  // Add the return value to the list being built up.
+  if (retvals_.size() <= retval_index) {
+    retvals_.resize(retval_index + 1);
+  }
+  retvals_[retval_index].is_constant = false;
+  retvals_[retval_index].handle = handle;
}

Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
                                  const xla::Literal& literal) {
  VLOG(1) << "Adding retval index " << retval_index
          << " with non-data-dependent tensor to XLA computation";
+  if (retvals_.size() <= retval_index) {
+    retvals_.resize(retval_index + 1);
+  }
  if (resolve_compile_time_constants_) {
-    ConstRetVal value;
-    value.index = retval_index;
-    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
-    compile_time_constant_.push_back(std::move(value));
+    retvals_[retval_index].is_constant = true;
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(
+        literal, dtype, &retvals_[retval_index].constant_value));
  } else {
-    retval_.emplace_back(retval_index, xla_builder_.ConstantLiteral(literal));
+    retvals_[retval_index].is_constant = false;
+    retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
  }
  return Status::OK();
}
@@ -239,40 +133,13 @@ void XlaContext::AddSideEffects() {
  has_side_effects_ = true;
}

-/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
+xla::ComputationBuilder* XlaContext::builder() { return builder_; }
const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK_NE(expression->handle().handle(), 0);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor(
Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK_EQ(expression->handle().handle(), 0);
return const_cast<XlaExpression*>(expression);
}
/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor(
const Tensor& tensor) {
return CastExpressionFromTensor(tensor);
}
/* static */ const xla::ComputationDataHandle&
XlaContext::GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
}
xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; }
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
  return LookupOrCreate(type, &max_func_, [this, type] {
    const string type_string = DataTypeString(type);
    VLOG(1) << "Building Max() for " << type_string;
-    xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">");
+    xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
    xla::PrimitiveType xla_type;
    TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");

@@ -286,7 +153,7 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
  return LookupOrCreate(type, &add_func_, [this, type] {
    const string type_string = DataTypeString(type);
    VLOG(1) << "Building Add() for " << type_string;
-    xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">");
+    xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
    xla::PrimitiveType xla_type;
    TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");

@@ -300,7 +167,7 @@ const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
  return LookupOrCreate(type, &sigmoid_func_, [this, type] {
    const string type_string = DataTypeString(type);
    VLOG(1) << "Building Sigmoid() for " << type_string;
-    xla::ComputationBuilder b(builder().client(),
+    xla::ComputationBuilder b(builder()->client(),
                              "sigmoid<" + type_string + ">");
    xla::PrimitiveType xla_type;
    TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-// This file defines the contexts used to represent XLA JIT computatations.
+// This file defines the contexts used during XLA compilation.

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
@@ -33,11 +33,11 @@ limitations under the License.
namespace tensorflow {

// A XlaExpression wraps an XLA computation. Each Tensor sent
-// along an edge during XLA JIT compilation represents a
+// along an edge during XLA compilation represents a
// XlaExpression, and the shape of the Tensor matches the shape of
// the subcomputation in the ComputationDataHandle. Each
-// expression is either a constant, an unbound parameter, or a
-// function of previously-compiled expressions.
+// expression is either a constant, or a function of previously-compiled
+// expressions.
class XlaExpression {
 public:
  XlaExpression();
@@ -52,8 +52,6 @@ class XlaExpression {
  const Tensor& constant_value() const { return constant_value_; }

 private:
-  friend class XlaContext;
-
  // The XLA handle of the expression's computation.
  xla::ComputationDataHandle handle_;
@@ -66,90 +64,52 @@ class XlaExpression {
  TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};

-// The XlaContext is the data structure accessible from
-// OpKernelContexts when evaluating a subgraph of Ops for JIT
-// compilation by XLA. When an Op is executed during JIT
+// The XlaContext is the data structure that holds the state of an XLA
+// compilation, that is accessible from OpKernelContexts when compiling a
+// subgraph of Ops using XLA.
// compilation the input Tensors to the Op store handles to
// subcomputations compiled by earlier Ops in the subgraph. The Op can
// retrieve these subcomputations by calling either
// GetExpressionFromTensor, which returns the XlaExpression holding
// the subcomputation; or EvaluateAsConstant which returns an XLA
// literal of the result of the subcomputation or an error status if
// the subcomputation depends on unbound parameters. The Op may then
// use the ComputationBuilder available from XlaContext::builder()
// to compile one or more functions of the inputs into
// ComputationDataHandles. The handles can be stored as new
// expressions corresponding to the outputs of the Op by calling
// CreateOutputTensorFromComputation or
// CreateConstantOutputTensor. The *only* correct way to allocate an
// output tensor is using one of the preceding two methods, since they
// ensure there is a valid XlaExpression backing the output
// tensor. No Op should ever call allocate_output or allocate_temp
// directly on the OpKernelContext. It is permissible to pass a tensor
// from an Op input to an output (e.g. call ctx->set_output with a
// tensor passed as an input). As an example, the softmax Op produces
// output from input as follows:
//
// XlaContext& tc = XlaContext::Get(context);
// xla::ComputationBuilder& b = tc.builder();
// xla::ComputationDataHandle logits =
// tc.GetComputationFromTensor(logits_in));
// ... The softmax computation uses the builder b to compute a
// xla::ComputationDataHandle softmax holding the desired output.
// ...
// OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation(
// context, 0, logits_in.shape().dim_sizes(),
// softmax));
//
class XlaContext : public ResourceBase {
 public:
-  // If a retval can be evaluated at JIT time it is returned as a
-  // Literal in a ConstRetVal struct as part of the ComputationResult.
-  // TODO(misard) reconcile this with the duplicate data structure in
-  // the XlaCompilationCache class.
-  struct ConstRetVal {
-    // The index of the RetVal corresponding to this constant literal.
-    int index;
-    // If status is not OK, value's data is undefined.
-    Status status;
-    // The value of the RetVal evaluated at JIT compilation
-    // time. value.shape() always gives the correct shape of the
-    // RetVal. If !status.ok() then value's data is undefined, otherwise the
-    // Tensor buffer is allocated in CPU memory.
-    Tensor value;
-  };
+  // A struct that represents either a compile-time constant, or an XLA
+  // computation handle. Used to represent arguments and return values.
+  struct HandleOrConstant {
+    // Is this a compile-time constant? If so, what is its value?
+    bool is_constant;
+    Tensor constant_value;  // Must be in host memory.
+
+    // If this is not a constant, a computation handle.
+    xla::ComputationDataHandle handle;
+  };

-  // Virtual method defined by ResourceBase.
-  string DebugString() override;
-
-  // Retrieve the XlaContext corresponding to a step's JIT compilation.
+  // Retrieves the XlaContext of the current compilation.
  static XlaContext& Get(const OpKernelContext* ctx);
  static XlaContext& Get(const XlaOpKernelContext* ctx) {
    return Get(ctx->op_kernel_context());
  }

-  // Create a new XlaContext.
-  XlaContext(XlaCompiler* compiler, xla::Client* client,
-             const string& computation_name, bool allow_cpu_custom_calls,
-             bool resolve_compile_time_constants);
-
-  // Builds XLA computations for each of the arguments.
-  // Should only be called once to initialize the arguments. Not thread-safe.
-  Status BuildArguments(std::vector<XlaCompiler::Argument> arguments,
-                        bool use_tuple_arg) TF_MUST_USE_RESULT;
-
-  // Returns the results of the symbolic computation that have accumulated in
-  // the XlaContext. After CollectResults() is called, the context is left in
-  // an invalid state and must not be reused.
-  // Sets `requires_runtime_context` if the emitted computation requires a
-  // runtime context argument. `compile_time_constants` describes any non
-  // data-dependent results of the computation. `num_nonconst_ouputs` is set to
-  // the number of outputs of the `computation`.
-  Status CollectResults(xla::Computation* computation,
-                        bool* requires_runtime_context,
-                        std::vector<ConstRetVal>* compile_time_constants,
-                        int* num_nonconst_outputs);
+  // Creates a new XlaContext.
+  XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
+             bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
+
+  // Virtual method defined by ResourceBase.
+  string DebugString() override;
+
+  XlaCompiler* compiler() const { return compiler_; }
+
+  // Returns the ComputationBuilder that Ops use for compiling new
+  // expressions.
+  xla::ComputationBuilder* builder();
+
+  bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
+  bool has_context_parameter() const { return has_context_parameter_; }
+
+  const std::vector<HandleOrConstant>& args() const { return args_; }
+  void set_args(std::vector<HandleOrConstant> args);
+
+  // Get the runtime context parameter, adding one if it does not already exist.
+  // Dies if not compiling a local executable.
+  const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
+
+  const std::vector<HandleOrConstant>& retvals() { return retvals_; }

  // This is called by the Retval Op to associate a computed value
  // with a specific return value of the subgraph.
@@ -159,30 +119,10 @@ class XlaContext : public ResourceBase {
  Status AddConstRetval(int retval_index, DataType dtype,
                        const xla::Literal& literal);

-  // Mark the computation as having side effects (i.e., Send operators).
+  // Mark the computation as having side effects (e.g., Send operators).
  void AddSideEffects();

-  // Retrieves the ComputationDataHandle from an input Tensor to an Op. This
+  bool has_side_effects() const { return has_side_effects_; }
// computation was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const xla::ComputationDataHandle& GetComputationFromTensor(
const Tensor& tensor);
XlaCompiler* compiler() const { return compiler_; }
// Returns the ComputationBuilder that Ops use for compiling new
// expressions.
xla::ComputationBuilder& builder();
const std::vector<XlaCompiler::Argument>& args() const { return args_; }
xla::ComputationDataHandle parameter(int num) { return parameters_[num]; }
// Get the runtime context parameter, adding one if it does not already exist.
// Dies if not compiling a local executable.
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
  // Get an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
@@ -203,39 +143,11 @@ class XlaContext : public ResourceBase {
  static const char kXlaContextResourceName[];

 private:
friend class XlaOpKernelContext;
// This method is used to retrieve an expression that was allocated by
// a previous Op.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
// This method is used to retrieve an uninitialized expression from a
// newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor);
// Retrieves the expression from an input Tensor to an Op. This
// expression was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
  XlaCompiler* const compiler_;

  // The ComputationBuilder used to construct the subgraph's compiled
  // representation.
-  xla::ComputationBuilder xla_builder_;
+  xla::ComputationBuilder* builder_;
// Number of XLA Parameters, not counting the context parameter, if any.
int num_parameters_;
// Arguments to the JIT compilation, both compile-time constant arguments and
// runtime parameters.
std::vector<XlaCompiler::Argument> args_;
bool use_tuple_arg_ = false;
// Runtime parameters to the XLA computation. Does not include
// compile-time constant arguments.
std::vector<xla::ComputationDataHandle> parameters_;
  // Allow ops to emit CustomCall operations for CPU.
  const bool allow_cpu_custom_calls_;
@@ -251,11 +163,12 @@ class XlaContext : public ResourceBase {
  bool has_context_parameter_ = false;
  xla::ComputationDataHandle context_parameter_;

-  // The data-dependent return values of the computation.
-  std::vector<std::pair<int, xla::ComputationDataHandle>> retval_;
-
-  // The non-data-dependent return values of the computation.
-  std::vector<ConstRetVal> compile_time_constant_;
+  // Arguments to the Tensorflow graph, indexed by _Arg index.
+  // Includes both compile-time constant arguments and runtime parameters.
+  std::vector<HandleOrConstant> args_;
+
+  // Return values of the Tensorflow graph, indexed by _Retval index.
+  std::vector<HandleOrConstant> retvals_;

  // Does the computation have side effects, i.e., Send() calls?
  bool has_side_effects_ = false;


@@ -31,11 +31,37 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
}

xla::ComputationBuilder* XlaOpKernelContext::builder() const {
-  return &XlaContext::Get(this).builder();
+  return XlaContext::Get(this).builder();
}
// Retrieves an XlaExpression that was allocated by a previous Op.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK_NE(expression->handle().handle(), 0);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK_EQ(expression->handle().handle(), 0);
return const_cast<XlaExpression*>(expression);
}
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
// computation was constructed by an Op that executed previously and
// created the output Tensor using CreateOutputTensorFromComputation
// or CreateConstantOutputTensor.
static const xla::ComputationDataHandle& GetComputationFromTensor(
const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
}

const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
-  return XlaContext::GetComputationFromTensor(context_->input(index));
+  return GetComputationFromTensor(context_->input(index));
}

TensorShape XlaOpKernelContext::InputShape(int index) {
@@ -60,8 +86,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
        " but was asked to be reshaped to incompatible shape ",
        new_shape.DebugString());
  }
-  const XlaExpression* expression =
-      XlaContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);

  // If the tensor has a known constant value, there is no need to invoke XLA.
  if (expression->has_constant_value()) {
@@ -159,7 +184,7 @@ Status XlaOpKernelContext::InputList(
  handles->clear();
  shapes->clear();
  for (const Tensor& input : inputs) {
-    handles->push_back(XlaContext::GetComputationFromTensor(input));
+    handles->push_back(GetComputationFromTensor(input));
    shapes->push_back(input.shape());
  }
  return Status::OK();
@@ -196,8 +221,7 @@ void XlaOpKernelContext::SetOutput(int index,
  // The expression is stored in the tensor's data buffer. Fill in the
  // fields now.
-  XlaExpression* expression =
-      XlaContext::CastExpressionFromUninitializedTensor(output);
+  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_handle(handle);
}
@@ -217,8 +241,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  // The expression is stored in the tensor's data buffer. Fill in the
  // fields now.
-  XlaExpression* expression =
-      XlaContext::CastExpressionFromUninitializedTensor(output);
+  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_handle(handle);
  expression->set_constant_value(constant);
}


@@ -45,9 +45,14 @@ class XlaOpKernel : public OpKernel {

// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and consumes
-// data as XLA computations, rather than as standard Tensors. (Under the hood,
-// symbolic execution communicates using special Tensors, but that is an
-// implementation detail that this class hides.)
+// data as XLA computations, rather than as standard Tensors.
+//
+// Under the hood, symbolic execution communicates using special Tensors that
+// wrap XlaExpression objects, however this is an implementation detail that
+// this class hides. The *only* correct way to allocate a Tensor during
+// compilation is using the XlaOpKernelContext methods, since they ensure there
+// is a valid XlaExpression backing the tensor. No Op should ever call
+// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);
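To make the allocation contract above concrete, here is a minimal hypothetical kernel written against this interface. The SquareOp name and the Mul call are illustrative only; Input, builder, and SetOutput are the methods declared in this header, and the kernel never touches allocate_output or allocate_temp on the underlying OpKernelContext.

// Hypothetical example: a symbolic kernel that squares its input.
// Every tensor it touches is created through XlaOpKernelContext, which keeps
// an XlaExpression behind each one.
class SquareOp : public XlaOpKernel {
 public:
  explicit SquareOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Input() unwraps the ComputationDataHandle stored in the input Tensor.
    const xla::ComputationDataHandle& x = ctx->Input(0);
    // Emit the symbolic multiply on the shared ComputationBuilder.
    xla::ComputationDataHandle y = ctx->builder()->Mul(x, x);
    // SetOutput() allocates the output Tensor and records the handle in its
    // backing XlaExpression.
    ctx->SetOutput(0, y);
  }
};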