[TF:XLA] Refactor XlaCompiler, XlaExpression, and XlaOpKernelContext in preparation for adding support for compiling small computations that don't correspond to a single TF op.

The idea of the refactoring is that XlaExpression is the canonical XLA representation of a symbolic TF value. So in general a computation to compile is a function with type [XlaExpression] -> [XlaExpression], and in a future change we will add a method to XlaCompiler that exposes pretty much exactly that API. The current TF function/graph/op compilation methods are specific ways to build such a function.

* Move XlaExpression into its own file. Improve its ergonomics; it is really a kind of sum type. Also move some useful common methods on XlaExpressions into the XlaExpression class.
* Add support for passing and returning XlaExpressions via XlaOpKernelContext, since they are the underlying representation. The remaining *Input() and *Output() methods are really just conveniences built on top.
* Simplify _Arg and _Retval to just get and set an XlaExpression from an XlaContext. Move logic to flatten return values out of _Retval and move it instead into XlaCompiler so it can be reused when compiling non-graph computations.
* Move logic to assign cores to arguments and return values into a common place in XlaCompiler.

PiperOrigin-RevId: 221104314
This commit is contained in:
Peter Hawkins 2018-11-12 09:26:02 -08:00 committed by TensorFlower Gardener
parent b62da0d3fa
commit f895a9e996
16 changed files with 843 additions and 592 deletions

View File

@ -166,6 +166,7 @@ cc_library(
"xla_compilation_device.cc",
"xla_compiler.cc",
"xla_context.cc",
"xla_expression.cc",
"xla_helpers.cc",
"xla_op_kernel.cc",
"xla_op_registry.cc",
@ -180,6 +181,7 @@ cc_library(
"xla_compilation_device.h",
"xla_compiler.h",
"xla_context.h",
"xla_expression.h",
"xla_helpers.h",
"xla_op_kernel.h",
"xla_op_registry.h",
@ -364,7 +366,10 @@ tf_cc_test(
tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
srcs = [
"xla_compiler_test.cc",
"xla_expression_test.cc",
],
deps = [
":common",
":side_effect_util",
@ -389,6 +394,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

View File

@ -23,9 +23,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@ -51,12 +52,11 @@ namespace {
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector<const XlaExpression*>& expressions,
std::vector<XlaCompiler::Argument>* args) {
auto builder = ctx->builder();
auto client = ctx->compiler()->client();
std::vector<bool> compile_time_constant_flags(expressions.size());
std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
TF_RETURN_IF_ERROR(
BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant,
/*compile_time_const_nodes=*/nullptr));
args->resize(expressions.size());
@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.type = ctx->input_type(i);
arg.shape = ctx->InputShape(i);
if (arg.type == DT_RESOURCE) {
return errors::InvalidArgument(
"Resource as function argument is not yet implemented.");
} else if (expressions[i]->has_constant_value()) {
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = expressions[i]->constant_value();
} else if (compile_time_constant_flags[i]) {
arg.kind = XlaCompiler::Argument::kConstant;
TF_RET_CHECK(expressions[i]->resource() == nullptr)
<< "Input with resource is not yet implemented.";
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
expressions[i]->handle()));
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
switch (expressions[i]->kind()) {
case XlaExpression::Kind::kConstant:
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = expressions[i]->constant_value();
break;
case XlaExpression::Kind::kXlaOp:
if (arg_must_be_compile_time_constant[i]) {
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
expressions[i]->ResolveConstant(client));
if (!value.has_value()) {
return errors::InvalidArgument(
"Argument to function must be a compile-time constant, but "
"unable to resolve argument value to a constant.");
}
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = *value;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}
break;
case XlaExpression::Kind::kResource:
return errors::Unimplemented(
"Resource as function argument is not yet implemented.");
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument("Invalid function argument");
}
}
return Status::OK();

View File

@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel {
}
const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
if (arg.resource() != nullptr) {
ctx->SetResourceOutput(0, arg.resource());
} else if (arg.has_constant_value()) {
ctx->SetConstantOutput(0, arg.constant_value());
} else {
ctx->SetOutput(0, arg.handle());
}
OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
errors::InvalidArgument("Invalid/missing argument expression"));
ctx->SetOutputExpression(0, arg);
}
private:

View File

@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape(proto_.tensor_shape());
if (proto_.dtype() == DT_STRING) {
LOG(WARNING) << "Not computing Const of type DT_STRING";
ctx->SetInvalidOutput(0);
return;
}
xla::XlaBuilder* b = ctx->builder();
// To avoid blowups for large constants filled with the same value,

View File

@ -47,63 +47,8 @@ class RetvalOp : public XlaOpKernel {
// compilation.
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
} else {
xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
DataType input_type = ctx->input_type(0);
XlaContext& tc = XlaContext::Get(ctx);
if (input_type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
return;
}
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
ctx->SetStatus(is_constant.status());
return;
}
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
ctx->SetStatus(is_constant.status());
xla::Shape representation_shape;
if (tc.is_entry_computation()) {
xla::StatusOr<xla::Shape> shape_or_status =
tc.RepresentationShape(shape, ctx->input_type(0));
if (!shape_or_status.ok()) {
ctx->SetStatus(shape_or_status.status());
return;
} else {
representation_shape = shape_or_status.ValueOrDie();
}
} else {
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(ctx->input_type(0), shape,
&representation_shape));
}
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
output = xla::Reshape(
input, xla::AsInt64Slice(representation_shape.dimensions()));
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
// the device assignment of "input" at this point, we must always
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
output =
xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
}
XlaContext& xla_context = XlaContext::Get(ctx);
xla_context.SetRetval(index_, ctx->InputExpression(0));
}
}

View File

@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto(
"XLACompilationDevice::MakeTensorFromProto should not be called");
}
XlaExpression::XlaExpression() = default;
void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }
void XlaExpression::set_constant_value(Tensor value) {
has_constant_value_ = true;
constant_value_ = std::move(value);
}
} // namespace tensorflow

View File

@ -18,9 +18,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
@ -38,8 +35,8 @@ class XlaCompilationAllocator;
// This is a 'dummy' TensorFlow device that is only used to execute a
// subgraph of XLA compilation Ops to construct a compiled version
// of the subgraph's computation. It has a 'dummy' allocator that
// backs each Tensor with metadata indicating the computation the
// Tensor represents.
// backs each Tensor with an XlaExpression. The shape of the Tensor
// matches the shape of XlaExpression.
//
// We deliberately don't register a device factory because we *never*
// want placement to put Ops on a compilation device. The device is created
@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};
// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
// matches the shape of the subcomputation in the XlaOp. Each
// expression is either a constant, or a function of previously-compiled
// expressions.
class XlaExpression {
public:
XlaExpression();
// handle() stores the XLA handle of the computation that the
// expression represents.
void set_handle(const xla::XlaOp& h);
const xla::XlaOp& handle() const { return handle_; }
void set_constant_value(Tensor value);
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
void set_resource(XlaResource* resource) { resource_ = resource; }
XlaResource* resource() const { return resource_; }
private:
// The XLA handle of the expression's computation.
xla::XlaOp handle_;
// If this expression is a constant with a known value, 'constant_value' is a
// host-memory Tensor containing the value. Used to avoid invoking XLA for
// expressions that are trivially constant.
bool has_constant_value_ = false;
Tensor constant_value_;
XlaResource* resource_ = nullptr; // Not owned.
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_

View File

@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@ -64,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types,
return Status::OK();
}
// Uses the _Arg and _Retval nodes in the graph to determine a core assignment
// for each argument and return value.
xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
ComputeArgAndRetvalCores(const Graph& graph) {
auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
return sharding.value().tile_assignment_devices(0);
} else {
return -1;
}
};
std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
for (const Node* n : graph.nodes()) {
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Arg index";
arg_cores[index] = core;
} else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Retval index";
TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n));
retval_cores[index] = core;
}
}
return std::make_pair(std::move(arg_cores), std::move(retval_cores));
}
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}
// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
// order.
// - `args_core` and `retval_cores` are mapping from arg/return indices to core
// assignments.
// - If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// - Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// - `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a ResourceUpdate, whose `index` is the index of a
// resource variable argument to the computation to be updated, and `type` is
// the type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<XlaExpression>& retvals,
const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
// Builds a no-op XLA computation. We need to set the sharding of outputs, but
// cannot change the sharding of the existing output op. To do this, we build
// a new identity op to which shardings can be applied.
auto identity_op = [builder](xla::XlaOp op) {
return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
};
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
const XlaExpression& retval = retvals[i];
output.type = retval.dtype();
switch (retval.kind()) {
case XlaExpression::Kind::kConstant:
output.is_constant = true;
output.constant_value = retval.constant_value();
output.shape = output.constant_value.shape();
break;
case XlaExpression::Kind::kXlaOp: {
output.is_constant = false;
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
xla::XlaOp value = retval.handle();
auto it = retval_cores.find(i);
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == retval_cores.end()
? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(it->second));
if (shape_representation_fn) {
// If there is a shape representation function, reshape the output
// tensor to the shape given by the representation shape function.
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
output.shape, output.type));
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
} else if (it != retval_cores.end()) {
// Apply the sharding to the output, if there is a core assignment.
value = identity_op(value);
}
elems.push_back(value);
break;
}
case XlaExpression::Kind::kResource:
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
output.shape = retval.resource()->shape();
break;
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument(
"Invalid expression returned by computation. "
"This probably means a return value was not set.");
}
}
*num_nonconst_outputs = elems.size();
// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});
for (const XlaResource* resource : arg_resources) {
DCHECK_LT(resource->arg_num(), args.size());
const XlaCompiler::Argument& arg = args[resource->arg_num()];
auto it = arg_cores.find(resource->arg_num());
const int core = it == arg_cores.end() ? -1 : it->second;
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Ensures the correct sharding is applied to the output.
handle = identity_op(handle);
elems.push_back(handle);
}
}
// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}
*num_computation_outputs = elems.size();
// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}
} // namespace
bool XlaCompiler::Argument::operator==(
@ -252,14 +488,16 @@ Status XlaCompiler::CompileFunction(
// lowest-numbered core that consumes the argument. We choose the
// lowest-numbered core so the assignment is deterministic.
for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Arg") {
if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kArgOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
}
}
// Do _Retval as a second loop, in case the retval's input is an _Arg (which
// may have gotten a device assignment from the first loop).
for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Retval") {
if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kRetOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
}
}
@ -353,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
}
}
namespace {
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}
// Builds the XLA computation.
// `args` is the list of input arguments, `retvals` is the list of retvals
// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
output.type = retvals[i].type;
output.shape = retvals[i].shape;
const XlaExpression& retval = retvals[i].expression;
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else if (retval.resource() != nullptr) {
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
}
}
*num_nonconst_outputs = elems.size();
// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
for (const XlaResource* resource : arg_resources) {
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Since we can't change the sharding metadata of <value> as this point,
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
elems.push_back(handle);
}
}
// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}
*num_computation_outputs = elems.size();
// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}
builder->ClearOpMetadata();
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}
} // namespace
// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation.
Status XlaCompiler::BuildArguments(
const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
bool is_entry_computation) {
arg_expressions->resize(args.size());
*arg_cores = std::vector<int>(args.size(), -1);
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
@ -543,7 +622,7 @@ Status XlaCompiler::BuildArguments(
arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
/*tensor_array_size=*/arg.tensor_array_size,
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
arg_expression = XlaExpression::Resource(resource);
if (arg.initialized) {
input_mapping->push_back(i);
}
@ -555,7 +634,7 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value);
arg_expression = XlaExpression::Constant(arg.constant_value);
break;
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@ -580,26 +659,6 @@ Status XlaCompiler::BuildArguments(
*input_shapes = arg_shapes;
}
// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (absl::string_view(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
<< "_Arg out of bounds: " << index << " vs " << args.size();
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
const int core = sharding.value().tile_assignment_devices(0);
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
(*arg_cores)[index] = core;
}
}
}
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
@ -615,11 +674,10 @@ Status XlaCompiler::BuildArguments(
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter];
const int root_device = 0;
auto it = arg_cores.find(parameter);
const int core = it == arg_cores.end() ? 0 : it->second;
*tuple_sharding.add_tuple_shardings() =
core == -1 ? xla::sharding_builder::AssignDevice(root_device)
: xla::sharding_builder::AssignDevice(core);
xla::sharding_builder::AssignDevice(core);
}
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
tuple_sharding);
@ -628,7 +686,8 @@ Status XlaCompiler::BuildArguments(
tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@ -636,7 +695,8 @@ Status XlaCompiler::BuildArguments(
}
} else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@ -671,14 +731,14 @@ Status XlaCompiler::BuildArguments(
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
arg_expression = XlaExpression::XlaOp(
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type);
} else {
arg_expression.set_handle(arg_handles[i]);
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
}
break;
case XlaCompiler::Argument::kToken: {
arg_expression.set_handle(arg_handles[i]);
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
break;
}
case XlaCompiler::Argument::kConstant:
@ -710,7 +770,7 @@ Status XlaCompiler::CompileSingleOp(
Node* node;
string arg_name = absl::StrCat("_arg", i);
Status status =
NodeBuilder(arg_name, "_Arg")
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
.ControlInput(graph->source_node())
.Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
: args[i].type)
@ -724,7 +784,7 @@ Status XlaCompiler::CompileSingleOp(
for (int64 i = 0; i < result_types.size(); ++i) {
Node* node;
string retval_name = absl::StrCat("_retval", i);
Status status = NodeBuilder(retval_name, "_Retval")
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
.Input(main_node, i)
.Attr("T", result_types[i])
.Attr("index", i)
@ -788,6 +848,32 @@ Status ValidateGraph(const Graph* graph,
return Status::OK();
}
// Converts the value of any expressions whose values are known at compile-time
// to constants.
Status ResolveConstantExpressionsToConstants(
xla::Client* client, absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kXlaOp) {
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
expression.ResolveConstant(client));
if (constant.has_value()) {
expression = XlaExpression::Constant(*constant);
}
}
}
return Status::OK();
}
void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kConstant) {
expression =
XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
}
}
}
} // namespace
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
@ -815,10 +901,9 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
options_.device_type, name));
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants, options.is_entry_computation,
&options_.shape_representation_fn);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
@ -833,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
real_args.push_back(token_arg);
}
std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
ComputeArgAndRetvalCores(*graph));
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
TF_RETURN_IF_ERROR(BuildArguments(
*graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
*graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
&arg_expressions, &result->input_mapping, &result->xla_input_shapes,
options.is_entry_computation));
context->set_args(std::move(arg_expressions));
@ -884,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
std::vector<XlaExpression> retvals = context->retvals();
if (options.resolve_compile_time_constants) {
TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
client(), absl::Span<XlaExpression>(retvals)));
} else {
ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
}
TF_RETURN_IF_ERROR(BuildComputation(
real_args, arg_cores, context->retvals(), context->resources(),
std::move(token_output), options.return_updated_values_for_all_resources,
real_args, retvals, arg_cores, retval_cores, context->resources(),
std::move(token_output),
options.is_entry_computation ? options_.shape_representation_fn
: ShapeRepresentationFn{},
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
&result->resource_updates));

View File

@ -21,8 +21,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
@ -415,7 +417,8 @@ class XlaCompiler {
Status BuildArguments(const Graph& graph,
const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder,
XlaContext* context, std::vector<int>* arg_cores,
XlaContext* context,
const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping,
std::vector<xla::Shape>* input_shapes,

View File

@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
bool allow_cpu_custom_calls,
const std::function<xla::StatusOr<xla::Shape>(
const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants),
is_entry_computation_(is_entry_computation),
shape_representation_fn_(shape_representation_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type,
const TensorShape& shape, const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
void XlaContext::SetRetval(int index, const XlaExpression& expression) {
if (retvals_.size() <= index) {
retvals_.resize(index + 1);
}
XlaExpression e;
e.set_handle(handle);
retvals_[retval_index] = Retval{type, shape, e};
retvals_[index] = expression;
}
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal) {
VLOG(1) << "Adding retval index " << retval_index
<< " with non-data-dependent tensor to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
Tensor value;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
XlaExpression e;
e.set_constant_value(value);
retvals_[retval_index] = Retval{dtype, value.shape(), e};
return Status::OK();
}
Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
VLOG(1) << "Adding retval index " << retval_index << " with resource "
<< resource->name() << ":" << resource->shape().DebugString()
<< " to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
XlaExpression e;
e.set_resource(resource);
retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
return Status::OK();
}
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
XlaResource::Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,

View File

@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -46,8 +46,7 @@ class XlaContext : public ResourceBase {
// Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
bool allow_cpu_custom_calls,
const std::function<xla::StatusOr<xla::Shape>(
const TensorShape&, DataType)>* shape_representation_fn);
@ -57,37 +56,19 @@ class XlaContext : public ResourceBase {
XlaCompiler* compiler() const { return compiler_; }
// Returns the XlaBuilder that Ops use for compiling new expressions.
xla::XlaBuilder* builder();
xla::XlaBuilder* builder() { return builder_; }
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool resolve_compile_time_constants() const {
return resolve_compile_time_constants_;
}
bool is_entry_computation() const { return is_entry_computation_; }
const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args);
struct Retval {
DataType type;
TensorShape shape;
// An XlaExpression representing the Retval's value.
XlaExpression expression;
};
const std::vector<Retval>& retvals() { return retvals_; }
const std::vector<XlaExpression>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void AddRetval(int retval_index, DataType type, const TensorShape& shape,
const xla::XlaOp& handle);
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal);
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
// Sets a return value.
// Since we do not always know in advance how many return values there are,
// grows the return values vector to size index+1 if it is smaller.
void SetRetval(int index, const XlaExpression& expression);
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
@ -140,24 +121,16 @@ class XlaContext : public ResourceBase {
// Allow ops to emit CustomCall operations for CPU.
const bool allow_cpu_custom_calls_;
// If true, constant return values are returned as Tensors instead of
// run-time computation outputs.
const bool resolve_compile_time_constants_;
// Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters.
std::vector<XlaExpression> args_;
// Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<Retval> retvals_;
std::vector<XlaExpression> retvals_;
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
// Is this a top-level computation, or an inner computation (e.g., a while
// body)?
const bool is_entry_computation_;
// Describes the on-host shapes of parameters and return values. Also see:
// XlaDevice::Options::shape_representation_fn.
const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*

View File

@ -0,0 +1,145 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
XlaExpression::XlaExpression() = default;
XlaExpression XlaExpression::Invalid() {
XlaExpression e;
e.kind_ = Kind::kInvalid;
return e;
}
XlaExpression XlaExpression::Constant(Tensor value) {
XlaExpression e;
e.kind_ = Kind::kConstant;
e.dtype_ = value.dtype();
e.constant_value_ = value;
return e;
}
XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
XlaExpression e;
e.kind_ = Kind::kXlaOp;
e.dtype_ = dtype;
e.handle_ = value;
return e;
}
XlaExpression XlaExpression::Resource(XlaResource* resource) {
XlaExpression e;
e.kind_ = Kind::kResource;
e.dtype_ = DT_RESOURCE;
e.resource_ = resource;
return e;
}
string XlaExpression::HumanString() const {
switch (kind_) {
case Kind::kInvalid:
return "invalid";
case Kind::kConstant:
return "constant";
case Kind::kXlaOp:
return "xla_op";
case Kind::kResource:
return "resource";
}
}
xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
switch (kind_) {
case Kind::kConstant: {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToBorrowingLiteral(constant_value_, &literal));
return xla::ConstantLiteral(builder, literal);
}
case Kind::kXlaOp:
if (builder != handle_.builder()) {
return errors::InvalidArgument(
"Mismatched builders in XlaExpression::AsXlaOp");
}
return handle_;
default:
return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
HumanString());
}
});
}
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client) const {
switch (kind()) {
case Kind::kConstant:
return {constant_value()};
case Kind::kXlaOp:
break;
case Kind::kResource:
case Kind::kInvalid:
return errors::InvalidArgument(
"ResolveConstant called on XlaExpression: ", HumanString());
}
TF_ASSIGN_OR_RETURN(bool is_constant,
handle().builder()->IsConstant(handle()));
if (!is_constant) return {absl::nullopt};
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildConstantSubGraph(handle()));
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
Tensor tensor;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
return {tensor};
}
xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
switch (kind_) {
case Kind::kConstant:
return constant_value().shape();
case Kind::kXlaOp: {
TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
handle().builder()->GetShape(handle()));
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
return shape;
}
case Kind::kResource:
return TensorShape({});
case Kind::kInvalid:
return errors::InvalidArgument(
"GetShape() called on invalid XlaExpression");
}
}
} // namespace tensorflow

View File

@ -0,0 +1,115 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
// compilation.
// An expression is one of:
// * a constant tensor.
// * an xla::XlaOp, representing a symbolic XLA value.
// * a resource, e.g., a variable, represented as an XlaResource pointer.
//
// Constant tensors are mostly an optimization to avoid passing large constants
// to XLA, but are also sometimes used to represent tensors that have no XLA
// representation, for example, DT_STRING tensors. A canonical use case might be
// an error message string.
class XlaExpression {
public:
enum class Kind {
kInvalid,
kConstant,
kXlaOp,
kResource,
};
XlaExpression();
XlaExpression(const XlaExpression&) = default;
XlaExpression& operator=(const XlaExpression&) = default;
// Builds an invalid expression. (Same as the default constructor, but makes
// the intent clearer.)
static XlaExpression Invalid();
// Builds a constant XLA expression.
static XlaExpression Constant(Tensor value);
// Builds a XlaOp expression. Since the mapping from TF data types to XLA
// types is not 1-1, the TF type must also be provided; in general it cannot
// be derived from the XLA type.
static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
// Builds a resource expression.
static XlaExpression Resource(XlaResource* resource);
Kind kind() const { return kind_; }
DataType dtype() const { return dtype_; }
// handle() returns the XlaOp that backs a kXlaOp expression.
const xla::XlaOp& handle() const { return handle_; }
const Tensor& constant_value() const { return constant_value_; }
XlaResource* resource() const { return resource_; }
// Returns a human-readable summary of the expression.
string HumanString() const;
// Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
// an erroneous XlaOp if the expression is not a constant or an expression.
xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
// If a kXlaOp or kConstant expression can be resolved to a compile-time
// constant, returns the value as a host-memory Tensor. Returns an empty
// optional if it cannot be resolved. Returns an error if passed a resource
// expression.
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client) const;
// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
// not the shape of the resource's value.
xla::StatusOr<TensorShape> GetShape() const;
private:
Kind kind_ = Kind::kInvalid;
DataType dtype_ = DT_INVALID;
// The XLA handle of the expression's computation, if kind_ == kXlaOp.
xla::XlaOp handle_;
// The value of the constant, if kind_ == kConstant.
Tensor constant_value_;
// The resource, if kind_ == kResource. Not owned.
XlaResource* resource_ = nullptr;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class XlaExpressionTest : public ::testing::Test {
protected:
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
constant_ = test::AsScalar<int32>(42);
op_ = xla::ConstantR0<int32>(builder_.get(), 7);
non_constant_op_ = xla::Parameter(
builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
resource_ = absl::make_unique<XlaResource>(
XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
/*tensor_array_gradients=*/std::set<string>(),
/*tensor_array_multiple_writes_aggregate=*/false);
}
xla::Client* client_;
std::unique_ptr<xla::XlaBuilder> builder_;
Tensor constant_;
xla::XlaOp op_;
xla::XlaOp non_constant_op_;
std::unique_ptr<XlaResource> resource_;
};
TEST_F(XlaExpressionTest, Kind) {
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind());
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind());
EXPECT_TRUE(XlaExpression::Kind::kConstant ==
XlaExpression::Constant(constant_).kind());
EXPECT_TRUE(XlaExpression::Kind::kXlaOp ==
XlaExpression::XlaOp(op_, DT_INT32).kind());
EXPECT_TRUE(XlaExpression::Kind::kResource ==
XlaExpression::Resource(resource_.get()).kind());
}
TEST_F(XlaExpressionTest, HumanString) {
EXPECT_EQ("invalid", XlaExpression().HumanString());
EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());
EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString());
EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString());
EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString());
}
TEST_F(XlaExpressionTest, AsXlaOp) {
xla::XlaOp op_as_op =
XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get());
EXPECT_TRUE(op_.IsIdenticalTo(op_as_op));
xla::XlaOp const_as_op =
XlaExpression::Constant(constant_).AsXlaOp(builder_.get());
TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation,
builder_->BuildConstantSubGraph(const_as_op));
TF_ASSERT_OK_AND_ASSIGN(xla::Literal value,
client_->ComputeConstant(computation));
EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0<int32>(42),
value));
}
TEST_F(XlaExpressionTest, GetShape) {
EXPECT_FALSE(XlaExpression().GetShape().ok());
EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());
TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape,
XlaExpression::Resource(resource_.get()).GetShape());
EXPECT_EQ(TensorShape({}), resource_shape);
TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape,
XlaExpression::XlaOp(op_, DT_INT32).GetShape());
EXPECT_EQ(TensorShape({}), op_shape);
TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape,
XlaExpression::Constant(constant_).GetShape());
EXPECT_EQ(TensorShape({}), constant_shape);
}
TEST_F(XlaExpressionTest, ResolveConstant) {
EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
EXPECT_FALSE(
XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> op_constant,
XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
ASSERT_TRUE(op_constant.has_value());
test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);
TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
.ResolveConstant(client_));
EXPECT_FALSE(op_nonconstant.has_value());
TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> constant_constant,
XlaExpression::Constant(constant_).ResolveConstant(client_));
ASSERT_TRUE(constant_constant.has_value());
test::ExpectTensorEqual<int32>(constant_, *constant_constant);
}
} // namespace
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK(expression->handle().valid() || expression->resource() != nullptr);
VLOG(1) << "Fetched T" << expression->handle();
CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
<< expression->HumanString();
return expression;
}
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
// Assigns an XlaExpression to a tensor on an XLA compilation device.
static void AssignExpressionToTensor(Tensor* tensor,
const XlaExpression& value) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK(!expression->handle().valid());
return const_cast<XlaExpression*>(expression);
CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
<< expression->HumanString();
*const_cast<XlaExpression*>(expression) = value;
}
// Retrieves the XlaOp from an input Tensor to an Op. This computation was
// constructed by an Op that executed previously and created the output Tensor
// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
return *CastExpressionFromTensor(context_->input(index));
}
const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
const XlaExpression& XlaOpKernelContext::InputExpression(
absl::string_view name) {
return *CastExpressionFromTensor(GetInputTensorByName(name));
}
const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
return GetComputationFromTensor(GetInputTensorByName(name));
xla::XlaOp XlaOpKernelContext::Input(int index) {
return InputExpression(index).AsXlaOp(builder());
}
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
return InputExpression(name).AsXlaOp(builder());
}
TensorShape XlaOpKernelContext::InputShape(int index) {
@ -125,59 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name,
Status XlaOpKernelContext::ConstantInputReshaped(
int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) {
const Tensor& tensor = context_->input(index);
TensorShape new_shape(new_dims);
if (tensor.NumElements() != new_shape.num_elements()) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
tensor.shape().DebugString(),
" but was asked to be reshaped to incompatible shape ",
new_shape.DebugString());
}
const XlaExpression* expression = CastExpressionFromTensor(tensor);
// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
Tensor temp(tensor.dtype());
if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
// This should never happen. The constant should have a shape compatible
// with the enclosing Tensor.
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
}
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}
// Make sure we treat zero-element tensors as constant.
if (new_shape.num_elements() == 0) {
Tensor temp(tensor.dtype(), new_shape);
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}
xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
handle = xla::Reshape(handle, new_shape.dim_sizes());
}
// The XLA layout is specified minor to major, and TensorFlow's minor
// dimension is the last one.
std::vector<int64> layout_indices(new_shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
if (!is_constant.ok()) {
Status status = is_constant.status();
XlaExpression e = InputExpression(index);
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
e.ResolveConstant(compiler()->client());
if (!constant_or_status.ok()) {
Status status = constant_or_status.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
context_->op_kernel().type_string(),
" operator as a compile-time constant.");
return status;
}
if (!is_constant.ValueOrDie()) {
absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
if (!constant.has_value()) {
return errors::InvalidArgument(
"Input ", index, " to ", context_->op_kernel().type_string(),
" operator must be a compile-time constant.\n"
@ -190,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped(
"stateful operation such as a random number generator.");
}
// Ask the XLA compiler to evaluate the data handle to a literal.
xla::StatusOr<xla::XlaComputation> constant_graph =
builder()->BuildConstantSubGraph(handle);
if (!constant_graph.ok()) {
return errors::Internal(
"Error getting a compile-time constant graph for ",
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
Tensor temp(constant->dtype());
if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
constant->shape().DebugString(),
" but was asked to be reshaped to incompatible shape ",
TensorShape(new_dims).DebugString());
}
xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
*constant_literal = std::move(computed).ValueOrDie();
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}
@ -363,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
handles->push_back(GetComputationFromTensor(input));
handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
shapes->push_back(input.shape());
}
return Status::OK();
@ -449,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
Tensor** output) {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
if (expected_output_dtype(index) == DT_VARIANT) {
// tensor_data() is not supported for variant Tensor (i.e.,
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
// XlaExpression inside the Tensor's tensor_data() does not work for
// variant. Instead construct a uint8 tensor and store the expression in its
// value.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
*output = new Tensor();
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, *output));
context_->set_output(index, **output);
} else {
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
void XlaOpKernelContext::SetOutputExpression(int index,
const XlaExpression& expression) {
Status status = [&] {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
Tensor* output = nullptr;
// Provides a special behavior for DT_VARIANT: a variant is treated as
// DT_UINT8 scalar as the type to allow mapping for variant to more generic
// types.
if (expression.dtype() == DT_VARIANT) {
// tensor_data() is not supported for variant Tensor (i.e.,
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
// XlaExpression inside the Tensor's tensor_data() does not work for
// variant. Instead construct a uint8 tensor and store the expression in
// its value.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
output = new Tensor();
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, output));
context_->set_output(index, *output);
} else {
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
}
AssignExpressionToTensor(output, expression);
return Status::OK();
}();
if (!status.ok()) {
SetStatus(status);
}
return Status::OK();
}
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
auto shape_or = builder()->GetShape(handle);
if (!shape_or.ok()) {
SetStatus(shape_or.status());
return;
}
OP_REQUIRES_OK(context_,
allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
SetOutputExpression(
index,
XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
const TensorShape& shape = constant.shape();
xla::BorrowingLiteral literal;
OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
CHECK(handle.valid());
// Make the Tensor that will refer to the expression.
Tensor* output = nullptr;
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
expression->set_constant_value(constant);
}
void XlaOpKernelContext::SetInvalidOutput(int index) {
Tensor* output = nullptr;
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape({}), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
xla::XlaOp handle;
expression->set_handle(handle);
SetOutputExpression(index, XlaExpression::Constant(constant));
}
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the resource itself
// (i.e., a scalar), not the shape of the resource's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_resource(resource);
SetOutputExpression(index, XlaExpression::Resource(resource));
}
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {

View File

@ -88,9 +88,9 @@ class XlaOpKernelContext {
// Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
xla::XlaOp Input(int index);
// Returns input `name` as a XlaOp.
const xla::XlaOp& Input(absl::string_view name);
xla::XlaOp Input(absl::string_view name);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@ -142,6 +142,10 @@ class XlaOpKernelContext {
Status ConstantInputList(absl::string_view name,
std::vector<xla::Literal>* literals);
// Returns an XlaExpression describing the value of 'index'.
const XlaExpression& InputExpression(int index);
const XlaExpression& InputExpression(absl::string_view name);
// Outputs
int num_outputs() const { return context_->num_outputs(); }
@ -159,9 +163,8 @@ class XlaOpKernelContext {
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);
// Sets output `index` to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
// Returns an XlaExpression describing the value of 'index'.
void SetOutputExpression(int index, const XlaExpression& expression);
// Status handling.
void SetStatus(const Status& status) { context_->SetStatus(status); }
@ -249,11 +252,6 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
// Wraps OpKernelContext's allocate_output method while providing special
// behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
// type to allow mapping for variant to more generic types.
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,