[TF:XLA] Refactor XlaContext, moving some of its reponsibilities to XlaCompiler and XlaOpKernelContext.
Move handling of arguments and return values to XlaCompiler. Introduce a new XlaContext::HandleOrConstant structure, use it for both arguments and results. Make XlaCompiler own the xla::ComputationBuilder. Move code for wrapping/unwrapping XlaExpressions in Tensors to XlaOpKernelContext, which is its only consumer. No functional changes. Change: 147250375
This commit is contained in:
parent
196ce75c99
commit
4e24bec418
@ -104,12 +104,12 @@ class ArgOp : public XlaOpKernel {
|
||||
|
||||
OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(),
|
||||
errors::InvalidArgument("Invalid argument index ", index_));
|
||||
const XlaCompiler::Argument& arg = tc.args()[index_];
|
||||
|
||||
if (arg.parameter < 0) {
|
||||
const XlaContext::HandleOrConstant& arg = tc.args()[index_];
|
||||
if (arg.is_constant) {
|
||||
ctx->SetConstantOutput(0, arg.constant_value);
|
||||
} else {
|
||||
ctx->SetOutput(0, tc.parameter(arg.parameter));
|
||||
ctx->SetOutput(0, arg.handle);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,10 +44,9 @@ class XlaCompilationAllocator : public Allocator {
|
||||
string Name() override { return "tla_jit"; }
|
||||
|
||||
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
||||
// Regardless of the size requested, always allocate a
|
||||
// XlaExpression. Respect the aligment request because there is
|
||||
// alignment checking even for Tensors whose data is never
|
||||
// accessed.
|
||||
// Regardless of the size requested, always allocates an XlaExpression.
|
||||
// Respects the aligment request because there is alignment checking even
|
||||
// for Tensors whose data is never accessed.
|
||||
void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment);
|
||||
XlaExpression* expression = reinterpret_cast<XlaExpression*>(p);
|
||||
new (expression) XlaExpression();
|
||||
|
@ -271,6 +271,109 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
return cleanup_status;
|
||||
}
|
||||
|
||||
// Builds XLA computations for each of the arguments to the computation.
|
||||
// `args` are the arguments to the computation. If `use_tuple_arg` is true, a
|
||||
// single tuple parameter will be used for all arguments; if false, each
|
||||
// argument gets its own parameter.
|
||||
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
bool use_tuple_arg, xla::ComputationBuilder* builder,
|
||||
std::vector<XlaContext::HandleOrConstant>* context_args) {
|
||||
context_args->resize(args.size());
|
||||
|
||||
// Computes the number of parameters, verifies that they are sequential
|
||||
// starting from 0.
|
||||
int num_parameters = 0;
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
(*context_args)[i].is_constant = (args[i].parameter < 0);
|
||||
(*context_args)[i].constant_value = args[i].constant_value;
|
||||
|
||||
if (args[i].parameter < 0) continue;
|
||||
if (num_parameters != args[i].parameter) {
|
||||
return errors::InvalidArgument(
|
||||
"Parameter numbers to XLA compilation are not consecutive starting "
|
||||
"from 0");
|
||||
}
|
||||
++num_parameters;
|
||||
|
||||
if (args[i].shape.num_elements() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Non-constant argument must have a non-zero number of elements.");
|
||||
}
|
||||
}
|
||||
if (num_parameters == 0) return Status::OK();
|
||||
|
||||
std::vector<xla::Shape> parameter_shapes(num_parameters);
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
// Computes the shapes of non-constant arguments.
|
||||
xla::PrimitiveType type;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
|
||||
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
|
||||
¶meter_shapes[arg.parameter]);
|
||||
}
|
||||
|
||||
if (use_tuple_arg && num_parameters > 0) {
|
||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
|
||||
xla::ComputationDataHandle tuple =
|
||||
builder->Parameter(0, tuple_shape, "arg_tuple");
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
(*context_args)[i].handle =
|
||||
builder->GetTupleElement(tuple, arg.parameter);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
(*context_args)[i].handle =
|
||||
builder->Parameter(arg.parameter, parameter_shapes[arg.parameter],
|
||||
strings::StrCat("arg", i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builds the XLA computation. `retvals` is the list of retvals produced by
|
||||
// _Retval operators, in index order. `has_side_effects` should be true if the
|
||||
// computation has side effects and should be built even if it has no outputs.
|
||||
// `num_nonconst_outputs` is set to the number of outputs of the `computation`.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaContext::HandleOrConstant>& retvals,
|
||||
bool has_side_effects, xla::ComputationBuilder* builder,
|
||||
xla::Computation* computation, int* num_nonconst_outputs) {
|
||||
std::vector<xla::ComputationDataHandle> elems;
|
||||
elems.reserve(retvals.size());
|
||||
for (const XlaContext::HandleOrConstant& retval : retvals) {
|
||||
if (!retval.is_constant) {
|
||||
elems.push_back(retval.handle);
|
||||
}
|
||||
}
|
||||
|
||||
if (!elems.empty() || has_side_effects) {
|
||||
// Builds a empty tuple return value for computations that have side effects
|
||||
// but have no return values.
|
||||
xla::ComputationDataHandle handle = builder->Tuple(elems);
|
||||
|
||||
// TODO(b/31775371): to workaround bug, we must build a no-op computation
|
||||
// that is guaranteed to be constructed after all of the formal parameters
|
||||
// to the computation. Once the bug is fixed, we could avoid tupling here.
|
||||
if (elems.size() == 1) {
|
||||
handle = builder->GetTupleElement(handle, 0);
|
||||
}
|
||||
|
||||
// Builds the XLA computation.
|
||||
xla::StatusOr<xla::Computation> computation_status = builder->Build();
|
||||
if (!computation_status.ok()) {
|
||||
return computation_status.status();
|
||||
}
|
||||
*computation = computation_status.ConsumeValueOrDie();
|
||||
}
|
||||
*num_nonconst_outputs = elems.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaCompiler::CompileGraph(string const& name,
|
||||
@ -292,41 +395,41 @@ Status XlaCompiler::CompileGraph(string const& name,
|
||||
args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
|
||||
}
|
||||
|
||||
XlaContext* xla_context =
|
||||
new XlaContext(this, client(), name, allow_cpu_custom_calls_,
|
||||
resolve_compile_time_constants_);
|
||||
core::ScopedUnref xla_context_unref(xla_context);
|
||||
xla::ComputationBuilder builder(client(), name);
|
||||
|
||||
TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
|
||||
XlaContext* context = new XlaContext(this, &builder, allow_cpu_custom_calls_,
|
||||
resolve_compile_time_constants_);
|
||||
core::ScopedUnref context_unref(context);
|
||||
|
||||
std::vector<XlaContext::HandleOrConstant> context_args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildArguments(args, use_tuple_arg, &builder, &context_args));
|
||||
context->set_args(std::move(context_args));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));
|
||||
ExecuteGraph(context, std::move(graph), device_, flib, NextStepId()));
|
||||
|
||||
std::vector<XlaContext::ConstRetVal> compile_time_constants;
|
||||
int num_nonconst_outputs;
|
||||
TF_RETURN_IF_ERROR(xla_context->CollectResults(
|
||||
&result->computation, &result->requires_runtime_context,
|
||||
&compile_time_constants, &num_nonconst_outputs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildComputation(context->retvals(), context->has_side_effects(),
|
||||
&builder, &result->computation, &num_nonconst_outputs));
|
||||
|
||||
VLOG(2) << "Outputs: constant: " << compile_time_constants.size()
|
||||
result->requires_runtime_context = context->has_context_parameter();
|
||||
|
||||
// Tuple arguments and runtime context parameters are incompatible.
|
||||
CHECK(!(use_tuple_arg && result->requires_runtime_context));
|
||||
|
||||
VLOG(2) << "Outputs: total: " << context->retvals().size()
|
||||
<< " nonconstant: " << num_nonconst_outputs;
|
||||
result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
|
||||
for (const auto& c : compile_time_constants) {
|
||||
if (!c.status.ok()) {
|
||||
Status constant_status = c.status;
|
||||
errors::AppendToMessage(&constant_status,
|
||||
"Failed evaluating constant XLA return "
|
||||
"value ",
|
||||
c.index);
|
||||
return constant_status;
|
||||
result->outputs.resize(context->retvals().size());
|
||||
for (int i = 0; i < context->retvals().size(); ++i) {
|
||||
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
|
||||
if (retval.is_constant) {
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.shape = retval.constant_value.shape();
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value;
|
||||
}
|
||||
if (c.index >= result->outputs.size()) {
|
||||
return errors::InvalidArgument("Invalid argument index ", c.index);
|
||||
}
|
||||
OutputDescription& output = result->outputs[c.index];
|
||||
output.shape = c.value.shape();
|
||||
output.is_constant = true;
|
||||
output.constant_value = c.value;
|
||||
}
|
||||
|
||||
if (result->computation.IsNull()) {
|
||||
@ -363,16 +466,18 @@ Status XlaCompiler::CompileGraph(string const& name,
|
||||
|
||||
// Converts the output shapes to TensorShapes.
|
||||
int computation_output = 0;
|
||||
for (int i = 0; i < result->outputs.size(); ++i) {
|
||||
if (!result->outputs[i].is_constant) {
|
||||
for (int i = 0; i < context->retvals().size(); ++i) {
|
||||
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
|
||||
if (!retval.is_constant) {
|
||||
CHECK_LT(computation_output, num_non_constant_outputs);
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.is_constant = false;
|
||||
if (num_non_constant_outputs > 1) {
|
||||
result->outputs[i].shape =
|
||||
output.shape =
|
||||
XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
|
||||
result->xla_output_shape, computation_output));
|
||||
} else {
|
||||
result->outputs[i].shape =
|
||||
XLAShapeToTensorShape(result->xla_output_shape);
|
||||
output.shape = XLAShapeToTensorShape(result->xla_output_shape);
|
||||
}
|
||||
++computation_output;
|
||||
}
|
||||
|
@ -31,8 +31,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -68,142 +66,33 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
|
||||
return *context;
|
||||
}
|
||||
|
||||
Status XlaContext::BuildArguments(std::vector<XlaCompiler::Argument> args,
|
||||
bool use_tuple_arg) {
|
||||
void XlaContext::set_args(std::vector<HandleOrConstant> args) {
|
||||
args_ = std::move(args);
|
||||
use_tuple_arg_ = use_tuple_arg;
|
||||
|
||||
// Compute the number of parameters, verify that they are sequential starting
|
||||
// from 0
|
||||
num_parameters_ = 0;
|
||||
for (const XlaCompiler::Argument& arg : args_) {
|
||||
if (arg.parameter < 0) continue;
|
||||
if (num_parameters_ != arg.parameter) {
|
||||
return errors::InvalidArgument(
|
||||
"Parameter numbers to JIT compilation are not consecutive starting "
|
||||
"from 0");
|
||||
}
|
||||
++num_parameters_;
|
||||
|
||||
if (arg.shape.num_elements() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Non-constant argument must have a non-zero number of elements.");
|
||||
}
|
||||
}
|
||||
if (num_parameters_ == 0) return Status::OK();
|
||||
|
||||
parameters_.resize(num_parameters_);
|
||||
|
||||
std::vector<xla::Shape> parameter_shapes(num_parameters_);
|
||||
for (int i = 0; i < args_.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args_[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
// Computes the shapes of non-constant arguments.
|
||||
xla::PrimitiveType type;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
|
||||
xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
|
||||
¶meter_shapes[arg.parameter]);
|
||||
}
|
||||
|
||||
if (use_tuple_arg_ && num_parameters_ > 0) {
|
||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
|
||||
xla::ComputationDataHandle tuple =
|
||||
builder().Parameter(0, tuple_shape, "arg_tuple");
|
||||
for (int i = 0; i < args_.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args_[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
parameters_[arg.parameter] =
|
||||
builder().GetTupleElement(tuple, arg.parameter);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < args_.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args_[i];
|
||||
if (arg.parameter < 0) continue;
|
||||
parameters_[arg.parameter] =
|
||||
builder().Parameter(arg.parameter, parameter_shapes[arg.parameter],
|
||||
strings::StrCat("arg", i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaContext::CollectResults(
|
||||
xla::Computation* computation, bool* requires_runtime_context,
|
||||
std::vector<ConstRetVal>* compile_time_constants,
|
||||
int* num_nonconst_outputs) {
|
||||
xla::ComputationDataHandle handle;
|
||||
if (retval_.empty() && has_side_effects_) {
|
||||
// Build a empty tuple return value for computations that have side effects
|
||||
// but have no return values.
|
||||
handle = builder().Tuple({});
|
||||
} else if (retval_.size() == 1) {
|
||||
handle = retval_[0].second;
|
||||
|
||||
// TODO(b/31775371): to workaround bug, add a no-op computation that is
|
||||
// guaranteed to be constructed after all of the formal parameters to the
|
||||
// computation.
|
||||
handle = builder().GetTupleElement(builder().Tuple({handle}), 0);
|
||||
|
||||
// Ensure that the retval is returned even if another computation
|
||||
// was mistakenly placed on the ComputationBuilder.
|
||||
TF_CHECK_OK(builder().SetReturnValue(handle));
|
||||
} else if (retval_.size() > 1) {
|
||||
// There is at least one data-dependent expression: combine them
|
||||
// into a Tuple in index order before compiling.
|
||||
VLOG(1) << "Making the retval tuple.";
|
||||
std::sort(retval_.begin(), retval_.end(),
|
||||
[](const std::pair<int, xla::ComputationDataHandle>& a,
|
||||
const std::pair<int, xla::ComputationDataHandle>& b) {
|
||||
return a.first < b.first;
|
||||
});
|
||||
std::vector<xla::ComputationDataHandle> elems;
|
||||
elems.reserve(retval_.size());
|
||||
for (const std::pair<int, xla::ComputationDataHandle>& r : retval_) {
|
||||
elems.push_back(r.second);
|
||||
}
|
||||
// Make a tuple from the vector of handles.
|
||||
handle = builder().Tuple(elems);
|
||||
}
|
||||
|
||||
if (handle.handle() > 0) {
|
||||
// Builds the XLA computation.
|
||||
xla::StatusOr<xla::Computation> computation_status = builder().Build();
|
||||
if (!computation_status.ok()) {
|
||||
return computation_status.status();
|
||||
}
|
||||
*computation = computation_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Make sure the compile time constants are in RetVal index order.
|
||||
std::sort(compile_time_constant_.begin(), compile_time_constant_.end(),
|
||||
[](const ConstRetVal& a, const ConstRetVal& b) {
|
||||
return a.index < b.index;
|
||||
});
|
||||
|
||||
// Fill in the result details and return.
|
||||
*compile_time_constants = std::move(compile_time_constant_);
|
||||
*requires_runtime_context = has_context_parameter_;
|
||||
*num_nonconst_outputs = retval_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
|
||||
const string& computation_name,
|
||||
XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
|
||||
bool allow_cpu_custom_calls,
|
||||
bool resolve_compile_time_constants)
|
||||
: compiler_(compiler),
|
||||
xla_builder_(client, computation_name),
|
||||
builder_(builder),
|
||||
allow_cpu_custom_calls_(allow_cpu_custom_calls),
|
||||
resolve_compile_time_constants_(resolve_compile_time_constants) {}
|
||||
|
||||
const xla::ComputationDataHandle&
|
||||
XlaContext::GetOrCreateRuntimeContextParameter() {
|
||||
CHECK(allow_cpu_custom_calls_);
|
||||
CHECK(!use_tuple_arg_);
|
||||
if (has_context_parameter_) return context_parameter_;
|
||||
has_context_parameter_ = true;
|
||||
context_parameter_ = xla_builder_.Parameter(
|
||||
num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
|
||||
|
||||
// Allocate the next available parameter for the context parameter.
|
||||
int num_parameters = 0;
|
||||
for (const HandleOrConstant& arg : args_) {
|
||||
if (!arg.is_constant) {
|
||||
++num_parameters;
|
||||
}
|
||||
}
|
||||
context_parameter_ = builder_->Parameter(
|
||||
num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
|
||||
return context_parameter_;
|
||||
}
|
||||
|
||||
@ -214,23 +103,28 @@ string XlaContext::DebugString() { return "TLA JIT context"; }
|
||||
void XlaContext::AddRetval(int retval_index,
|
||||
const xla::ComputationDataHandle& handle) {
|
||||
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
|
||||
// Add the return value to the list being built up. The executor
|
||||
// is multi-threaded so this has to happen under the
|
||||
// lock.
|
||||
retval_.emplace_back(retval_index, handle);
|
||||
// Add the return value to the list being built up.
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
retvals_[retval_index].is_constant = false;
|
||||
retvals_[retval_index].handle = handle;
|
||||
}
|
||||
|
||||
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::Literal& literal) {
|
||||
VLOG(1) << "Adding retval index " << retval_index
|
||||
<< " with non-data-dependent tensor to XLA computation";
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
if (resolve_compile_time_constants_) {
|
||||
ConstRetVal value;
|
||||
value.index = retval_index;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
|
||||
compile_time_constant_.push_back(std::move(value));
|
||||
retvals_[retval_index].is_constant = true;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(
|
||||
literal, dtype, &retvals_[retval_index].constant_value));
|
||||
} else {
|
||||
retval_.emplace_back(retval_index, xla_builder_.ConstantLiteral(literal));
|
||||
retvals_[retval_index].is_constant = false;
|
||||
retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -239,40 +133,13 @@ void XlaContext::AddSideEffects() {
|
||||
has_side_effects_ = true;
|
||||
}
|
||||
|
||||
/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
|
||||
const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK_NE(expression->handle().handle(), 0);
|
||||
VLOG(1) << "Fetched T" << expression->handle().handle();
|
||||
return expression;
|
||||
}
|
||||
|
||||
/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor(
|
||||
Tensor* tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
|
||||
CHECK_EQ(expression->handle().handle(), 0);
|
||||
return const_cast<XlaExpression*>(expression);
|
||||
}
|
||||
|
||||
/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor(
|
||||
const Tensor& tensor) {
|
||||
return CastExpressionFromTensor(tensor);
|
||||
}
|
||||
|
||||
/* static */ const xla::ComputationDataHandle&
|
||||
XlaContext::GetComputationFromTensor(const Tensor& tensor) {
|
||||
return CastExpressionFromTensor(tensor)->handle();
|
||||
}
|
||||
|
||||
xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; }
|
||||
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
|
||||
|
||||
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
|
||||
return LookupOrCreate(type, &max_func_, [this, type] {
|
||||
const string type_string = DataTypeString(type);
|
||||
VLOG(1) << "Building Max() for " << type_string;
|
||||
xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">");
|
||||
xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
|
||||
xla::PrimitiveType xla_type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
|
||||
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
|
||||
@ -286,7 +153,7 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
|
||||
return LookupOrCreate(type, &add_func_, [this, type] {
|
||||
const string type_string = DataTypeString(type);
|
||||
VLOG(1) << "Building Add() for " << type_string;
|
||||
xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">");
|
||||
xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
|
||||
xla::PrimitiveType xla_type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
|
||||
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
|
||||
@ -300,7 +167,7 @@ const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
|
||||
return LookupOrCreate(type, &sigmoid_func_, [this, type] {
|
||||
const string type_string = DataTypeString(type);
|
||||
VLOG(1) << "Building Sigmoid() for " << type_string;
|
||||
xla::ComputationBuilder b(builder().client(),
|
||||
xla::ComputationBuilder b(builder()->client(),
|
||||
"sigmoid<" + type_string + ">");
|
||||
xla::PrimitiveType xla_type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the contexts used to represent XLA JIT computatations.
|
||||
// This file defines the contexts used during XLA compilation.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
|
||||
@ -33,11 +33,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// A XlaExpression wraps an XLA computation. Each Tensor sent
|
||||
// along an edge during XLA JIT compilation represents a
|
||||
// along an edge during XLA compilation represents a
|
||||
// XlaExpression, and the shape of the Tensor matches the shape of
|
||||
// the subcomputation in the ComputationDataHandle. Each
|
||||
// expression is either a constant, an unbound parameter, or a
|
||||
// function of previously-compiled expressions.
|
||||
// expression is either a constant, or a function of previously-compiled
|
||||
// expressions.
|
||||
class XlaExpression {
|
||||
public:
|
||||
XlaExpression();
|
||||
@ -52,8 +52,6 @@ class XlaExpression {
|
||||
const Tensor& constant_value() const { return constant_value_; }
|
||||
|
||||
private:
|
||||
friend class XlaContext;
|
||||
|
||||
// The XLA handle of the expression's computation.
|
||||
xla::ComputationDataHandle handle_;
|
||||
|
||||
@ -66,90 +64,52 @@ class XlaExpression {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
|
||||
};
|
||||
|
||||
// The XlaContext is the data structure accessible from
|
||||
// OpKernelContexts when evaluating a subgraph of Ops for JIT
|
||||
// compilation by XLA. When an Op is executed during JIT
|
||||
// compilation the input Tensors to the Op store handles to
|
||||
// subcomputations compiled by earlier Ops in the subgraph. The Op can
|
||||
// retrieve these subcomputations by calling either
|
||||
// GetExpressionFromTensor, which returns the XlaExpression holding
|
||||
// the subcomputation; or EvaluateAsConstant which returns an XLA
|
||||
// literal of the result of the subcomputation or an error status if
|
||||
// the subcomputation depends on unbound parameters. The Op may then
|
||||
// use the ComputationBuilder available from XlaContext::builder()
|
||||
// to compile one or more functions of the inputs into
|
||||
// ComputationDataHandles. The handles can be stored as new
|
||||
// expressions corresponding to the outputs of the Op by calling
|
||||
// CreateOutputTensorFromComputation or
|
||||
// CreateConstantOutputTensor. The *only* correct way to allocate an
|
||||
// output tensor is using one of the preceding two methods, since they
|
||||
// ensure there is a valid XlaExpression backing the output
|
||||
// tensor. No Op should ever call allocate_output or allocate_temp
|
||||
// directly on the OpKernelContext. It is permissible to pass a tensor
|
||||
// from an Op input to an output (e.g. call ctx->set_output with a
|
||||
// tensor passed as an input). As an example, the softmax Op produces
|
||||
// output from input as follows:
|
||||
//
|
||||
// XlaContext& tc = XlaContext::Get(context);
|
||||
// xla::ComputationBuilder& b = tc.builder();
|
||||
// xla::ComputationDataHandle logits =
|
||||
// tc.GetComputationFromTensor(logits_in));
|
||||
// ... The softmax computation uses the builder b to compute a
|
||||
// xla::ComputationDataHandle softmax holding the desired output.
|
||||
// ...
|
||||
// OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation(
|
||||
// context, 0, logits_in.shape().dim_sizes(),
|
||||
// softmax));
|
||||
//
|
||||
// The XlaContext is the data structure that holds the state of an XLA
|
||||
// compilation, that is accessible from OpKernelContexts when compiling a
|
||||
// subgraph of Ops using XLA.
|
||||
class XlaContext : public ResourceBase {
|
||||
public:
|
||||
// If a retval can be evaluated at JIT time it is returned as a
|
||||
// Literal in a ConstRetVal struct as part of the ComputationResult.
|
||||
// TODO(misard) reconcile this with the duplicate data structure in
|
||||
// the XlaCompilationCache class.
|
||||
struct ConstRetVal {
|
||||
// The index of the RetVal corresponding to this constant literal.
|
||||
int index;
|
||||
// If status is not OK, value's data is undefined.
|
||||
Status status;
|
||||
// The value of the RetVal evaluated at JIT compilation
|
||||
// time. value.shape() always gives the correct shape of the
|
||||
// RetVal. If !status.ok() then value's data is undefined, otherwise the
|
||||
// Tensor buffer is allocated in CPU memory.
|
||||
Tensor value;
|
||||
// A struct that represents either a compile-time constant, or an XLA
|
||||
// computation handle. Used to represent arguments and return values.
|
||||
struct HandleOrConstant {
|
||||
// Is this a compile-time constant? If so, what is its value?
|
||||
bool is_constant;
|
||||
Tensor constant_value; // Must be in host memory.
|
||||
|
||||
// If this is not a constant, a computation handle.
|
||||
xla::ComputationDataHandle handle;
|
||||
};
|
||||
|
||||
|
||||
// Virtual method defined by ResourceBase.
|
||||
string DebugString() override;
|
||||
|
||||
// Retrieve the XlaContext corresponding to a step's JIT compilation.
|
||||
// Retrieves the XlaContext of the current compilation.
|
||||
static XlaContext& Get(const OpKernelContext* ctx);
|
||||
static XlaContext& Get(const XlaOpKernelContext* ctx) {
|
||||
return Get(ctx->op_kernel_context());
|
||||
}
|
||||
|
||||
// Create a new XlaContext.
|
||||
XlaContext(XlaCompiler* compiler, xla::Client* client,
|
||||
const string& computation_name, bool allow_cpu_custom_calls,
|
||||
bool resolve_compile_time_constants);
|
||||
// Creates a new XlaContext.
|
||||
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
|
||||
bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
|
||||
|
||||
// Builds XLA computations for each of the arguments.
|
||||
// Should only be called once to initialize the arguments. Not thread-safe.
|
||||
Status BuildArguments(std::vector<XlaCompiler::Argument> arguments,
|
||||
bool use_tuple_arg) TF_MUST_USE_RESULT;
|
||||
// Virtual method defined by ResourceBase.
|
||||
string DebugString() override;
|
||||
|
||||
// Returns the results of the symbolic computation that have accumulated in
|
||||
// the XlaContext. After CollectResults() is called, the context is left in
|
||||
// an invalid state and must not be reused.
|
||||
// Sets `requires_runtime_context` if the emitted computation requires a
|
||||
// runtime context argument. `compile_time_constants` describes any non
|
||||
// data-dependent results of the computation. `num_nonconst_ouputs` is set to
|
||||
// the number of outputs of the `computation`.
|
||||
Status CollectResults(xla::Computation* computation,
|
||||
bool* requires_runtime_context,
|
||||
std::vector<ConstRetVal>* compile_time_constants,
|
||||
int* num_nonconst_outputs);
|
||||
XlaCompiler* compiler() const { return compiler_; }
|
||||
|
||||
// Returns the ComputationBuilder that Ops use for compiling new
|
||||
// expressions.
|
||||
xla::ComputationBuilder* builder();
|
||||
|
||||
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
|
||||
bool has_context_parameter() const { return has_context_parameter_; }
|
||||
|
||||
const std::vector<HandleOrConstant>& args() const { return args_; }
|
||||
void set_args(std::vector<HandleOrConstant> args);
|
||||
|
||||
// Get the runtime context parameter, adding one if it does not already exist.
|
||||
// Dies if not compiling a local executable.
|
||||
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
|
||||
|
||||
const std::vector<HandleOrConstant>& retvals() { return retvals_; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
@ -159,30 +119,10 @@ class XlaContext : public ResourceBase {
|
||||
Status AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::Literal& literal);
|
||||
|
||||
// Mark the computation as having side effects (i.e., Send operators).
|
||||
// Mark the computation as having side effects (e.g., Send operators).
|
||||
void AddSideEffects();
|
||||
|
||||
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
|
||||
// computation was constructed by an Op that executed previously and
|
||||
// created the output Tensor using CreateOutputTensorFromComputation
|
||||
// or CreateConstantOutputTensor.
|
||||
static const xla::ComputationDataHandle& GetComputationFromTensor(
|
||||
const Tensor& tensor);
|
||||
|
||||
XlaCompiler* compiler() const { return compiler_; }
|
||||
|
||||
// Returns the ComputationBuilder that Ops use for compiling new
|
||||
// expressions.
|
||||
xla::ComputationBuilder& builder();
|
||||
|
||||
const std::vector<XlaCompiler::Argument>& args() const { return args_; }
|
||||
xla::ComputationDataHandle parameter(int num) { return parameters_[num]; }
|
||||
|
||||
// Get the runtime context parameter, adding one if it does not already exist.
|
||||
// Dies if not compiling a local executable.
|
||||
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
|
||||
|
||||
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
|
||||
bool has_side_effects() const { return has_side_effects_; }
|
||||
|
||||
// Get an XLA lambda to compute Max. This is cached in the
|
||||
// XlaContext since it may be used by multiple Ops. There is a
|
||||
@ -203,39 +143,11 @@ class XlaContext : public ResourceBase {
|
||||
static const char kXlaContextResourceName[];
|
||||
|
||||
private:
|
||||
friend class XlaOpKernelContext;
|
||||
|
||||
// This method is used to retrieve an expression that was allocated by
|
||||
// a previous Op.
|
||||
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
|
||||
|
||||
// This method is used to retrieve an uninitialized expression from a
|
||||
// newly-allocated tensor.
|
||||
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor);
|
||||
|
||||
// Retrieves the expression from an input Tensor to an Op. This
|
||||
// expression was constructed by an Op that executed previously and
|
||||
// created the output Tensor using CreateOutputTensorFromComputation
|
||||
// or CreateConstantOutputTensor.
|
||||
static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
|
||||
|
||||
XlaCompiler* const compiler_;
|
||||
|
||||
// The ComputationBuilder used to construct the subgraph's compiled
|
||||
// representation.
|
||||
xla::ComputationBuilder xla_builder_;
|
||||
|
||||
// Number of XLA Parameters, not counting the context parameter, if any.
|
||||
int num_parameters_;
|
||||
|
||||
// Arguments to the JIT compilation, both compile-time constant arguments and
|
||||
// runtime parameters.
|
||||
std::vector<XlaCompiler::Argument> args_;
|
||||
bool use_tuple_arg_ = false;
|
||||
|
||||
// Runtime parameters to the XLA computation. Does not include
|
||||
// compile-time constant arguments.
|
||||
std::vector<xla::ComputationDataHandle> parameters_;
|
||||
xla::ComputationBuilder* builder_;
|
||||
|
||||
// Allow ops to emit CustomCall operations for CPU.
|
||||
const bool allow_cpu_custom_calls_;
|
||||
@ -251,11 +163,12 @@ class XlaContext : public ResourceBase {
|
||||
bool has_context_parameter_ = false;
|
||||
xla::ComputationDataHandle context_parameter_;
|
||||
|
||||
// The data-dependent return values of the computation.
|
||||
std::vector<std::pair<int, xla::ComputationDataHandle>> retval_;
|
||||
// Arguments to the Tensorflow graph, indexed by _Arg index.
|
||||
// Includes both compile-time constant arguments and runtime parameters.
|
||||
std::vector<HandleOrConstant> args_;
|
||||
|
||||
// The non-data-dependent return values of the computation.
|
||||
std::vector<ConstRetVal> compile_time_constant_;
|
||||
// Return values of the Tensorflow graph, indexed by _Retval index.
|
||||
std::vector<HandleOrConstant> retvals_;
|
||||
|
||||
// Does the computation have side effects, i.e., Send() calls?
|
||||
bool has_side_effects_ = false;
|
||||
|
@ -31,11 +31,37 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
|
||||
}
|
||||
|
||||
xla::ComputationBuilder* XlaOpKernelContext::builder() const {
|
||||
return &XlaContext::Get(this).builder();
|
||||
return XlaContext::Get(this).builder();
|
||||
}
|
||||
|
||||
// Retrieves an XlaExpression that was allocated by a previous Op.
|
||||
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK_NE(expression->handle().handle(), 0);
|
||||
VLOG(1) << "Fetched T" << expression->handle().handle();
|
||||
return expression;
|
||||
}
|
||||
|
||||
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
|
||||
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
|
||||
CHECK_EQ(expression->handle().handle(), 0);
|
||||
return const_cast<XlaExpression*>(expression);
|
||||
}
|
||||
|
||||
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
|
||||
// computation was constructed by an Op that executed previously and
|
||||
// created the output Tensor using CreateOutputTensorFromComputation
|
||||
// or CreateConstantOutputTensor.
|
||||
static const xla::ComputationDataHandle& GetComputationFromTensor(
|
||||
const Tensor& tensor) {
|
||||
return CastExpressionFromTensor(tensor)->handle();
|
||||
}
|
||||
|
||||
const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
|
||||
return XlaContext::GetComputationFromTensor(context_->input(index));
|
||||
return GetComputationFromTensor(context_->input(index));
|
||||
}
|
||||
|
||||
TensorShape XlaOpKernelContext::InputShape(int index) {
|
||||
@ -60,8 +86,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
|
||||
" but was asked to be reshaped to incompatible shape ",
|
||||
new_shape.DebugString());
|
||||
}
|
||||
const XlaExpression* expression =
|
||||
XlaContext::CastExpressionFromTensor(tensor);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
|
||||
// If the tensor has a known constant value, there is no need to invoke XLA.
|
||||
if (expression->has_constant_value()) {
|
||||
@ -159,7 +184,7 @@ Status XlaOpKernelContext::InputList(
|
||||
handles->clear();
|
||||
shapes->clear();
|
||||
for (const Tensor& input : inputs) {
|
||||
handles->push_back(XlaContext::GetComputationFromTensor(input));
|
||||
handles->push_back(GetComputationFromTensor(input));
|
||||
shapes->push_back(input.shape());
|
||||
}
|
||||
return Status::OK();
|
||||
@ -196,8 +221,7 @@ void XlaOpKernelContext::SetOutput(int index,
|
||||
|
||||
// The expression is stored in the tensor's data buffer. Fill in the
|
||||
// fields now.
|
||||
XlaExpression* expression =
|
||||
XlaContext::CastExpressionFromUninitializedTensor(output);
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_handle(handle);
|
||||
}
|
||||
|
||||
@ -217,8 +241,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
|
||||
|
||||
// The expression is stored in the tensor's data buffer. Fill in the
|
||||
// fields now.
|
||||
XlaExpression* expression =
|
||||
XlaContext::CastExpressionFromUninitializedTensor(output);
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_handle(handle);
|
||||
expression->set_constant_value(constant);
|
||||
}
|
||||
|
@ -45,9 +45,14 @@ class XlaOpKernel : public OpKernel {
|
||||
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
|
||||
// implementing operators that perform symbolic execution as part of the XLA
|
||||
// compiler. The key difference is that XlaOpKernelContext produces and consumes
|
||||
// data as XLA computations, rather than as standard Tensors. (Under the hood,
|
||||
// symbolic execution communicates using special Tensors, but that is an
|
||||
// implementation detail that this class hides.)
|
||||
// data as XLA computations, rather than as standard Tensors.
|
||||
//
|
||||
// Under the hood, symbolic execution communicates using special Tensors that
|
||||
// wrap XlaExpression objects, however this is an implementation detail that
|
||||
// this class hides. The *only* correct way to allocate a Tensor during
|
||||
// compilation is using the XlaOpKernelContext methods, since they ensure there
|
||||
// is a valid XlaExpression backing the tensor. No Op should ever call
|
||||
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
|
||||
class XlaOpKernelContext {
|
||||
public:
|
||||
explicit XlaOpKernelContext(OpKernelContext* context);
|
||||
|
Loading…
Reference in New Issue
Block a user