[TF:XLA] Refactor XlaCompiler, XlaExpression, and XlaOpKernelContext in preparation for adding support for compiling small computations that don't correspond to a single TF op.
The idea of the refactoring is that XlaExpression is the canonical XLA representation of a symbolic TF value. So in general a computation to compile is a function with type [XlaExpression] -> [XlaExpression], and in a future change we will add a method to XlaCompiler that exposes pretty much exactly that API. The current TF function/graph/op compilation methods are specific ways to build such a function.
* Move XlaExpression into its own file. Improve its ergonomics; it is really a kind of sum type. Also move some useful common methods on XlaExpressions into the XlaExpression class.
* Add support for passing and returning XlaExpressions via XlaOpKernelContext, since they are the underlying representation. The remaining *Input() and *Output() methods are really just conveniences built on top.
* Simplify _Arg and _Retval to just get and set an XlaExpression from an XlaContext. Move the logic that flattens return values out of _Retval and into XlaCompiler so it can be reused when compiling non-graph computations.
* Move logic to assign cores to arguments and return values into a common place in XlaCompiler.
PiperOrigin-RevId: 221104314
This commit is contained in:
parent
b62da0d3fa
commit
f895a9e996
@ -166,6 +166,7 @@ cc_library(
|
||||
"xla_compilation_device.cc",
|
||||
"xla_compiler.cc",
|
||||
"xla_context.cc",
|
||||
"xla_expression.cc",
|
||||
"xla_helpers.cc",
|
||||
"xla_op_kernel.cc",
|
||||
"xla_op_registry.cc",
|
||||
@ -180,6 +181,7 @@ cc_library(
|
||||
"xla_compilation_device.h",
|
||||
"xla_compiler.h",
|
||||
"xla_context.h",
|
||||
"xla_expression.h",
|
||||
"xla_helpers.h",
|
||||
"xla_op_kernel.h",
|
||||
"xla_op_registry.h",
|
||||
@ -364,7 +366,10 @@ tf_cc_test(
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_compiler_test",
|
||||
srcs = ["xla_compiler_test.cc"],
|
||||
srcs = [
|
||||
"xla_compiler_test.cc",
|
||||
"xla_expression_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":side_effect_util",
|
||||
@ -389,6 +394,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -23,9 +23,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -51,12 +52,11 @@ namespace {
|
||||
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
const std::vector<const XlaExpression*>& expressions,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
auto builder = ctx->builder();
|
||||
auto client = ctx->compiler()->client();
|
||||
std::vector<bool> compile_time_constant_flags(expressions.size());
|
||||
std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
|
||||
BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant,
|
||||
/*compile_time_const_nodes=*/nullptr));
|
||||
|
||||
args->resize(expressions.size());
|
||||
@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
arg.type = ctx->input_type(i);
|
||||
arg.shape = ctx->InputShape(i);
|
||||
|
||||
if (arg.type == DT_RESOURCE) {
|
||||
return errors::InvalidArgument(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
} else if (expressions[i]->has_constant_value()) {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = expressions[i]->constant_value();
|
||||
} else if (compile_time_constant_flags[i]) {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
TF_RET_CHECK(expressions[i]->resource() == nullptr)
|
||||
<< "Input with resource is not yet implemented.";
|
||||
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
|
||||
expressions[i]->handle()));
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
client->ComputeConstant(constant_graph));
|
||||
TF_RETURN_IF_ERROR(
|
||||
LiteralToHostTensor(literal, arg.type, &arg.constant_value));
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
switch (expressions[i]->kind()) {
|
||||
case XlaExpression::Kind::kConstant:
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = expressions[i]->constant_value();
|
||||
break;
|
||||
case XlaExpression::Kind::kXlaOp:
|
||||
if (arg_must_be_compile_time_constant[i]) {
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
|
||||
expressions[i]->ResolveConstant(client));
|
||||
if (!value.has_value()) {
|
||||
return errors::InvalidArgument(
|
||||
"Argument to function must be a compile-time constant, but "
|
||||
"unable to resolve argument value to a constant.");
|
||||
}
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = *value;
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
}
|
||||
break;
|
||||
case XlaExpression::Kind::kResource:
|
||||
return errors::Unimplemented(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
case XlaExpression::Kind::kInvalid:
|
||||
return errors::InvalidArgument("Invalid function argument");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -14,11 +14,13 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
|
||||
if (arg.resource() != nullptr) {
|
||||
ctx->SetResourceOutput(0, arg.resource());
|
||||
} else if (arg.has_constant_value()) {
|
||||
ctx->SetConstantOutput(0, arg.constant_value());
|
||||
} else {
|
||||
ctx->SetOutput(0, arg.handle());
|
||||
}
|
||||
OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
|
||||
errors::InvalidArgument("Invalid/missing argument expression"));
|
||||
ctx->SetOutputExpression(0, arg);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape(proto_.tensor_shape());
|
||||
|
||||
if (proto_.dtype() == DT_STRING) {
|
||||
LOG(WARNING) << "Not computing Const of type DT_STRING";
|
||||
ctx->SetInvalidOutput(0);
|
||||
return;
|
||||
}
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
|
||||
// To avoid blowups for large constants filled with the same value,
|
||||
|
@ -47,63 +47,8 @@ class RetvalOp : public XlaOpKernel {
|
||||
// compilation.
|
||||
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
|
||||
} else {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
DataType input_type = ctx->input_type(0);
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
|
||||
if (input_type == DT_RESOURCE) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
|
||||
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
|
||||
return;
|
||||
}
|
||||
|
||||
auto is_constant = ctx->builder()->IsConstant(input);
|
||||
if (!is_constant.ok()) {
|
||||
ctx->SetStatus(is_constant.status());
|
||||
return;
|
||||
}
|
||||
|
||||
if (tc.resolve_compile_time_constants() &&
|
||||
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
|
||||
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
|
||||
} else {
|
||||
TensorShape shape = ctx->InputShape(0);
|
||||
ctx->SetStatus(is_constant.status());
|
||||
xla::Shape representation_shape;
|
||||
if (tc.is_entry_computation()) {
|
||||
xla::StatusOr<xla::Shape> shape_or_status =
|
||||
tc.RepresentationShape(shape, ctx->input_type(0));
|
||||
if (!shape_or_status.ok()) {
|
||||
ctx->SetStatus(shape_or_status.status());
|
||||
return;
|
||||
} else {
|
||||
representation_shape = shape_or_status.ValueOrDie();
|
||||
}
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(ctx->input_type(0), shape,
|
||||
&representation_shape));
|
||||
}
|
||||
|
||||
xla::XlaOp output = input;
|
||||
if (tc.is_entry_computation()) {
|
||||
output = xla::Reshape(
|
||||
input, xla::AsInt64Slice(representation_shape.dimensions()));
|
||||
} else {
|
||||
// The core from which a return value is returned depends on the
|
||||
// device assignment of the input to the retval. Since we can't change
|
||||
// the device assignment of "input" at this point, we must always
|
||||
// introduce an operator here, even if the shape does not change.
|
||||
// TODO(b/76097077): propagate device assignments onto arguments and
|
||||
// return values of functions, and then reshape unconditionally.
|
||||
output =
|
||||
xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
|
||||
}
|
||||
tc.AddRetval(index_, dtype_, shape, output);
|
||||
}
|
||||
XlaContext& xla_context = XlaContext::Get(ctx);
|
||||
xla_context.SetRetval(index_, ctx->InputExpression(0));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto(
|
||||
"XLACompilationDevice::MakeTensorFromProto should not be called");
|
||||
}
|
||||
|
||||
XlaExpression::XlaExpression() = default;
|
||||
|
||||
void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }
|
||||
|
||||
void XlaExpression::set_constant_value(Tensor value) {
|
||||
has_constant_value_ = true;
|
||||
constant_value_ = std::move(value);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,9 +18,6 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -38,8 +35,8 @@ class XlaCompilationAllocator;
|
||||
// This is a 'dummy' TensorFlow device that is only used to execute a
|
||||
// subgraph of XLA compilation Ops to construct a compiled version
|
||||
// of the subgraph's computation. It has a 'dummy' allocator that
|
||||
// backs each Tensor with metadata indicating the computation the
|
||||
// Tensor represents.
|
||||
// backs each Tensor with an XlaExpression. The shape of the Tensor
|
||||
// matches the shape of XlaExpression.
|
||||
//
|
||||
// We deliberately don't register a device factory because we *never*
|
||||
// want placement to put Ops on a compilation device. The device is created
|
||||
@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice {
|
||||
std::unique_ptr<XlaCompilationAllocator> allocator_;
|
||||
};
|
||||
|
||||
// A XlaExpression wraps an XLA computation. Each Tensor on an
|
||||
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
|
||||
// matches the shape of the subcomputation in the XlaOp. Each
|
||||
// expression is either a constant, or a function of previously-compiled
|
||||
// expressions.
|
||||
class XlaExpression {
|
||||
public:
|
||||
XlaExpression();
|
||||
|
||||
// handle() stores the XLA handle of the computation that the
|
||||
// expression represents.
|
||||
void set_handle(const xla::XlaOp& h);
|
||||
const xla::XlaOp& handle() const { return handle_; }
|
||||
|
||||
void set_constant_value(Tensor value);
|
||||
bool has_constant_value() const { return has_constant_value_; }
|
||||
const Tensor& constant_value() const { return constant_value_; }
|
||||
|
||||
void set_resource(XlaResource* resource) { resource_ = resource; }
|
||||
XlaResource* resource() const { return resource_; }
|
||||
|
||||
private:
|
||||
// The XLA handle of the expression's computation.
|
||||
xla::XlaOp handle_;
|
||||
|
||||
// If this expression is a constant with a known value, 'constant_value' is a
|
||||
// host-memory Tensor containing the value. Used to avoid invoking XLA for
|
||||
// expressions that are trivially constant.
|
||||
bool has_constant_value_ = false;
|
||||
Tensor constant_value_;
|
||||
|
||||
XlaResource* resource_ = nullptr; // Not owned.
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
|
||||
|
@ -36,11 +36,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -64,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Uses the _Arg and _Retval nodes in the graph to determine a core assignment
|
||||
// for each argument and return value.
|
||||
xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
|
||||
ComputeArgAndRetvalCores(const Graph& graph) {
|
||||
auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharding,
|
||||
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
|
||||
if (sharding.has_value()) {
|
||||
TF_RET_CHECK(sharding.value().type() ==
|
||||
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
|
||||
return sharding.value().tile_assignment_devices(0);
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
std::map<int, int> arg_cores;
|
||||
std::map<int, int> retval_cores;
|
||||
for (const Node* n : graph.nodes()) {
|
||||
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
|
||||
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
|
||||
if (core < 0) continue;
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RET_CHECK(index >= 0) << "Negative _Arg index";
|
||||
arg_cores[index] = core;
|
||||
} else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
|
||||
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
|
||||
if (core < 0) continue;
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RET_CHECK(index >= 0) << "Negative _Retval index";
|
||||
TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n));
|
||||
retval_cores[index] = core;
|
||||
}
|
||||
}
|
||||
return std::make_pair(std::move(arg_cores), std::move(retval_cores));
|
||||
}
|
||||
|
||||
// Compiles `graph` with a GraphCompiler running on `device`, after registering
// `xla_context` with the device's resource manager under the step container
// created for `step_id`.
//
// Returns the first error from registering the context or from compilation.
// If both succeed, returns the status reported by the resource manager's
// Cleanup of the step container.
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
                    XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
                    int64 step_id) {
  // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
  // resource manager takes ownership via Create, and unrefs via Cleanup. We
  // explicitly add a reference to ensure the refcount at entry is maintained at
  // all exit points; Create and Cleanup are always called in this function.
  //
  // The Executor requires us to use ScopedStepContainer. We wrap it in a
  // unique_ptr so we can capture the cleanup status in the end.
  xla_context->Ref();
  Status status;
  auto step_container = absl::make_unique<ScopedStepContainer>(
      step_id, [&status, device](const string& name) {
        status = device->resource_manager()->Cleanup(name);
      });
  TF_RETURN_IF_ERROR(device->resource_manager()->Create(
      step_container->name(), XlaContext::kXlaContextResourceName,
      xla_context));

  GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
  TF_RETURN_IF_ERROR(graph_compiler.Compile());
  // Explicitly clean up the step container, to capture the cleanup status.
  step_container.reset();
  // Report the captured cleanup status instead of discarding it.
  return status;
}
|
||||
|
||||
// Builds the XLA computation.
|
||||
// - `args` is the list of input arguments
|
||||
// - `retvals` is the list of retvals produced by _Retval operators, in index
|
||||
// order.
|
||||
// - `args_core` and `retval_cores` are mapping from arg/return indices to core
|
||||
// assignments.
|
||||
// - If `return_updated_values_for_all_resources` is true, all resources will be
|
||||
// included in `resource_updates`, regardless of whether their value changed.
|
||||
// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
|
||||
// - Sets `*resource_updates` to a description of resources whose values are
|
||||
// written by the computation; the variable writes are the last
|
||||
// - `resource_updates.size()` return values from the computation. Each entry in
|
||||
// `resource_updates` is a ResourceUpdate, whose `index` is the index of a
|
||||
// resource variable argument to the computation to be updated, and `type` is
|
||||
// the type of the final output.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::vector<XlaExpression>& retvals,
|
||||
const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
|
||||
const std::vector<std::unique_ptr<XlaResource>>& resources,
|
||||
std::unique_ptr<xla::XlaOp> token_output,
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
|
||||
bool return_updated_values_for_all_resources, bool always_return_tuple,
|
||||
xla::XlaBuilder* builder, xla::XlaComputation* computation,
|
||||
int* num_computation_outputs, int* num_nonconst_outputs,
|
||||
std::vector<XlaCompiler::OutputDescription>* outputs,
|
||||
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
|
||||
// Attach a common operator name as metadata. This has no semantic effect — it
|
||||
// merely makes the HLO graph more readable when visualized via TensorBoard,
|
||||
// since TensorBoard forms groups out of operators with similar names.
|
||||
xla::OpMetadata retval_metadata;
|
||||
retval_metadata.set_op_name("XLA_Retvals");
|
||||
builder->SetOpMetadata(retval_metadata);
|
||||
auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
|
||||
|
||||
// Builds a no-op XLA computation. We need to set the sharding of outputs, but
|
||||
// cannot change the sharding of the existing output op. To do this, we build
|
||||
// a new identity op to which shardings can be applied.
|
||||
auto identity_op = [builder](xla::XlaOp op) {
|
||||
return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
|
||||
};
|
||||
|
||||
std::vector<xla::XlaOp> elems;
|
||||
elems.reserve(retvals.size());
|
||||
for (int i = 0; i < retvals.size(); ++i) {
|
||||
XlaCompiler::OutputDescription& output = (*outputs)[i];
|
||||
const XlaExpression& retval = retvals[i];
|
||||
output.type = retval.dtype();
|
||||
switch (retval.kind()) {
|
||||
case XlaExpression::Kind::kConstant:
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value();
|
||||
output.shape = output.constant_value.shape();
|
||||
break;
|
||||
|
||||
case XlaExpression::Kind::kXlaOp: {
|
||||
output.is_constant = false;
|
||||
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
|
||||
xla::XlaOp value = retval.handle();
|
||||
auto it = retval_cores.find(i);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, it == retval_cores.end()
|
||||
? absl::optional<xla::OpSharding>()
|
||||
: xla::sharding_builder::AssignDevice(it->second));
|
||||
if (shape_representation_fn) {
|
||||
// If there is a shape representation function, reshape the output
|
||||
// tensor to the shape given by the representation shape function.
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
|
||||
output.shape, output.type));
|
||||
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
|
||||
} else if (it != retval_cores.end()) {
|
||||
// Apply the sharding to the output, if there is a core assignment.
|
||||
value = identity_op(value);
|
||||
}
|
||||
elems.push_back(value);
|
||||
break;
|
||||
}
|
||||
|
||||
case XlaExpression::Kind::kResource:
|
||||
output.is_constant = false;
|
||||
output.input_index = retval.resource()->arg_num();
|
||||
output.shape = retval.resource()->shape();
|
||||
break;
|
||||
|
||||
case XlaExpression::Kind::kInvalid:
|
||||
return errors::InvalidArgument(
|
||||
"Invalid expression returned by computation. "
|
||||
"This probably means a return value was not set.");
|
||||
}
|
||||
}
|
||||
*num_nonconst_outputs = elems.size();
|
||||
|
||||
// Add return values for resources whose values have changed.
|
||||
std::vector<const XlaResource*> arg_resources;
|
||||
arg_resources.reserve(resources.size());
|
||||
for (const auto& resource : resources) {
|
||||
if (resource->arg_num() >= 0) {
|
||||
arg_resources.push_back(resource.get());
|
||||
}
|
||||
}
|
||||
std::sort(arg_resources.begin(), arg_resources.end(),
|
||||
[](const XlaResource* a, const XlaResource* b) {
|
||||
return a->arg_num() < b->arg_num();
|
||||
});
|
||||
|
||||
for (const XlaResource* resource : arg_resources) {
|
||||
DCHECK_LT(resource->arg_num(), args.size());
|
||||
const XlaCompiler::Argument& arg = args[resource->arg_num()];
|
||||
auto it = arg_cores.find(resource->arg_num());
|
||||
const int core = it == arg_cores.end() ? -1 : it->second;
|
||||
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
|
||||
// TensorArray gradients were modified if their values changed or there are
|
||||
// any newly created gradients.
|
||||
for (const auto& grad : resource->tensor_array_gradients()) {
|
||||
modified =
|
||||
modified ||
|
||||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
|
||||
arg.tensor_array_gradients.count(grad.first) == 0;
|
||||
}
|
||||
if (return_updated_values_for_all_resources || modified) {
|
||||
resource_updates->emplace_back();
|
||||
XlaCompiler::ResourceUpdate& update = resource_updates->back();
|
||||
update.input_index = resource->arg_num();
|
||||
update.type = resource->type();
|
||||
update.shape = resource->shape();
|
||||
update.modified = modified;
|
||||
for (const auto& grad : resource->tensor_array_gradients()) {
|
||||
update.tensor_array_gradients_accessed.insert(grad.first);
|
||||
}
|
||||
|
||||
// Request that the value be returned on a specific core.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? absl::optional<xla::OpSharding>()
|
||||
: xla::sharding_builder::AssignDevice(core));
|
||||
|
||||
xla::XlaOp handle;
|
||||
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
|
||||
|
||||
// Ensures the correct sharding is applied to the output.
|
||||
handle = identity_op(handle);
|
||||
|
||||
elems.push_back(handle);
|
||||
}
|
||||
}
|
||||
|
||||
// If we have token output, append it as the last one.
|
||||
if (token_output) {
|
||||
elems.push_back(*token_output);
|
||||
}
|
||||
|
||||
*num_computation_outputs = elems.size();
|
||||
|
||||
// Builds the XLA computation. We *always* form a tuple here to ensure that
|
||||
// the output value is the last thing added into the XLA computation, even
|
||||
// if there is only one output value.
|
||||
auto tuple = xla::Tuple(builder, elems);
|
||||
if (!always_return_tuple && elems.size() == 1) {
|
||||
xla::GetTupleElement(tuple, 0);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
|
||||
if (!computation_status.ok()) {
|
||||
return computation_status.status();
|
||||
}
|
||||
*computation = computation_status.ConsumeValueOrDie();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool XlaCompiler::Argument::operator==(
|
||||
@ -252,14 +488,16 @@ Status XlaCompiler::CompileFunction(
|
||||
// lowest-numbered core that consumes the argument. We choose the
|
||||
// lowest-numbered core so the assignment is deterministic.
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::string_view(n->type_string()) == "_Arg") {
|
||||
if (absl::string_view(n->type_string()) ==
|
||||
FunctionLibraryDefinition::kArgOp) {
|
||||
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
|
||||
}
|
||||
}
|
||||
// Do _Retval as a second loop, in case the retval's input is an _Arg (which
|
||||
// may have gotten a device assignment from the first loop).
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::string_view(n->type_string()) == "_Retval") {
|
||||
if (absl::string_view(n->type_string()) ==
|
||||
FunctionLibraryDefinition::kRetOp) {
|
||||
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
|
||||
}
|
||||
}
|
||||
@ -353,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
|
||||
int64 step_id) {
|
||||
// Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
|
||||
// resource manager takes ownership via Create, and unrefs via Cleanup. We
|
||||
// explicitly add a reference to ensure the refcount at entry is maintained at
|
||||
// all exit points; Create and Cleanup are always called in this function.
|
||||
//
|
||||
// The Executor requires us to use ScopedStepContainer. We wrap it in a
|
||||
// unique_ptr so we can capture the cleanup status in the end.
|
||||
xla_context->Ref();
|
||||
Status status;
|
||||
auto step_container = absl::make_unique<ScopedStepContainer>(
|
||||
step_id, [&status, device](const string& name) {
|
||||
status = device->resource_manager()->Cleanup(name);
|
||||
});
|
||||
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
|
||||
step_container->name(), XlaContext::kXlaContextResourceName,
|
||||
xla_context));
|
||||
|
||||
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
|
||||
TF_RETURN_IF_ERROR(graph_compiler.Compile());
|
||||
// Explicitly clean up the step container, to capture the cleanup status.
|
||||
step_container.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builds the XLA computation.
|
||||
// `args` is the list of input arguments, `retvals` is the list of retvals
|
||||
// produced by _Retval operators, in index order.
|
||||
// If `return_updated_values_for_all_resources` is true, all resources will be
|
||||
// included in `resource_updates`, regardless of whether their value changed.
|
||||
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
|
||||
// Sets `*resource_updates` to a description of resources whose values are
|
||||
// written by the computation; the variable writes are the last
|
||||
// `resource_updates.size()` return values from the computation. Each entry in
|
||||
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
|
||||
// index of a resource variable argument to the computation, and `type` is the
|
||||
// type of the final output.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::vector<int>& arg_cores,
|
||||
const std::vector<XlaContext::Retval>& retvals,
|
||||
const std::vector<std::unique_ptr<XlaResource>>& resources,
|
||||
std::unique_ptr<xla::XlaOp> token_output,
|
||||
bool return_updated_values_for_all_resources, bool always_return_tuple,
|
||||
xla::XlaBuilder* builder, xla::XlaComputation* computation,
|
||||
int* num_computation_outputs, int* num_nonconst_outputs,
|
||||
std::vector<XlaCompiler::OutputDescription>* outputs,
|
||||
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
|
||||
std::vector<xla::XlaOp> elems;
|
||||
elems.reserve(retvals.size());
|
||||
for (int i = 0; i < retvals.size(); ++i) {
|
||||
XlaCompiler::OutputDescription& output = (*outputs)[i];
|
||||
output.type = retvals[i].type;
|
||||
output.shape = retvals[i].shape;
|
||||
const XlaExpression& retval = retvals[i].expression;
|
||||
if (retval.has_constant_value()) {
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value();
|
||||
} else if (retval.resource() != nullptr) {
|
||||
output.is_constant = false;
|
||||
output.input_index = retval.resource()->arg_num();
|
||||
} else {
|
||||
output.is_constant = false;
|
||||
elems.push_back(retval.handle());
|
||||
}
|
||||
}
|
||||
*num_nonconst_outputs = elems.size();
|
||||
|
||||
// Add return values for resources whose values have changed.
|
||||
std::vector<const XlaResource*> arg_resources;
|
||||
arg_resources.reserve(resources.size());
|
||||
for (const auto& resource : resources) {
|
||||
if (resource->arg_num() >= 0) {
|
||||
arg_resources.push_back(resource.get());
|
||||
}
|
||||
}
|
||||
std::sort(arg_resources.begin(), arg_resources.end(),
|
||||
[](const XlaResource* a, const XlaResource* b) {
|
||||
return a->arg_num() < b->arg_num();
|
||||
});
|
||||
|
||||
// Attach a common operator name as metadata. This has no semantic effect — it
|
||||
// merely makes the HLO graph more readable when visualized via TensorBoard,
|
||||
// since TensorBoard forms groups out of operators with similar names.
|
||||
xla::OpMetadata retval_metadata;
|
||||
retval_metadata.set_op_name("XLA_Retvals");
|
||||
builder->SetOpMetadata(retval_metadata);
|
||||
|
||||
for (const XlaResource* resource : arg_resources) {
|
||||
const XlaCompiler::Argument& arg = args[resource->arg_num()];
|
||||
const int core = arg_cores[resource->arg_num()];
|
||||
DCHECK_LT(resource->arg_num(), arg_cores.size());
|
||||
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
|
||||
// TensorArray gradients were modified if their values changed or there are
|
||||
// any newly created gradients.
|
||||
for (const auto& grad : resource->tensor_array_gradients()) {
|
||||
modified =
|
||||
modified ||
|
||||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
|
||||
arg.tensor_array_gradients.count(grad.first) == 0;
|
||||
}
|
||||
if (return_updated_values_for_all_resources || modified) {
|
||||
resource_updates->emplace_back();
|
||||
XlaCompiler::ResourceUpdate& update = resource_updates->back();
|
||||
update.input_index = resource->arg_num();
|
||||
update.type = resource->type();
|
||||
update.shape = resource->shape();
|
||||
update.modified = modified;
|
||||
for (const auto& grad : resource->tensor_array_gradients()) {
|
||||
update.tensor_array_gradients_accessed.insert(grad.first);
|
||||
}
|
||||
|
||||
// Request that the value be returned on a specific core.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? absl::optional<xla::OpSharding>()
|
||||
: xla::sharding_builder::AssignDevice(core));
|
||||
|
||||
xla::XlaOp handle;
|
||||
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
|
||||
|
||||
// Since we can't change the sharding metadata of <value> at this point,
|
||||
// create a tuple/get-tuple-element combination so that sharding
|
||||
// assignment will be placed on this value, which will cause the resource
|
||||
// update to be returned from the same device that provided the resource.
|
||||
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
|
||||
elems.push_back(handle);
|
||||
}
|
||||
}
|
||||
|
||||
// If we have token output, append it as the last one.
|
||||
if (token_output) {
|
||||
elems.push_back(*token_output);
|
||||
}
|
||||
|
||||
*num_computation_outputs = elems.size();
|
||||
|
||||
// Builds the XLA computation. We *always* form a tuple here to ensure that
|
||||
// the output value is the last thing added into the XLA computation, even
|
||||
// if there is only one output value.
|
||||
auto tuple = xla::Tuple(builder, elems);
|
||||
if (!always_return_tuple && elems.size() == 1) {
|
||||
xla::GetTupleElement(tuple, 0);
|
||||
}
|
||||
builder->ClearOpMetadata();
|
||||
|
||||
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
|
||||
if (!computation_status.ok()) {
|
||||
return computation_status.status();
|
||||
}
|
||||
*computation = computation_status.ConsumeValueOrDie();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds XLA computations for each of the arguments to the computation.
|
||||
// `args` are the arguments to the computation.
|
||||
Status XlaCompiler::BuildArguments(
|
||||
const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
|
||||
bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
|
||||
std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
|
||||
const std::map<int, int>& arg_cores,
|
||||
std::vector<XlaExpression>* arg_expressions,
|
||||
std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
|
||||
bool is_entry_computation) {
|
||||
arg_expressions->resize(args.size());
|
||||
*arg_cores = std::vector<int>(args.size(), -1);
|
||||
|
||||
// Argument numbers of arguments and resources that are to be passed to the
|
||||
// XLA computation as runtime parameters.
|
||||
@ -543,7 +622,7 @@ Status XlaCompiler::BuildArguments(
|
||||
arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
|
||||
/*tensor_array_size=*/arg.tensor_array_size,
|
||||
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
|
||||
arg_expression.set_resource(resource);
|
||||
arg_expression = XlaExpression::Resource(resource);
|
||||
if (arg.initialized) {
|
||||
input_mapping->push_back(i);
|
||||
}
|
||||
@ -555,7 +634,7 @@ Status XlaCompiler::BuildArguments(
|
||||
break;
|
||||
}
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
arg_expression.set_constant_value(arg.constant_value);
|
||||
arg_expression = XlaExpression::Constant(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
return errors::Internal(
|
||||
@ -580,26 +659,6 @@ Status XlaCompiler::BuildArguments(
|
||||
*input_shapes = arg_shapes;
|
||||
}
|
||||
|
||||
// Use the _Arg nodes in the graph to resolve core assignments.
|
||||
for (const Node* n : graph.nodes()) {
|
||||
if (absl::string_view(n->type_string()) != "_Arg") continue;
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RET_CHECK(index >= 0 && index < args.size())
|
||||
<< "_Arg out of bounds: " << index << " vs " << args.size();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharding,
|
||||
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
|
||||
if (sharding.has_value()) {
|
||||
TF_RET_CHECK(sharding.value().type() ==
|
||||
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
|
||||
const int core = sharding.value().tile_assignment_devices(0);
|
||||
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
|
||||
(*arg_cores)[index] = core;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Attach a common operator name as metadata. This has no semantic effect — it
|
||||
// merely makes the HLO graph more readable when visualized via TensorBoard,
|
||||
// since TensorBoard forms groups out of operators with similar names.
|
||||
@ -615,11 +674,10 @@ Status XlaCompiler::BuildArguments(
|
||||
xla::OpSharding tuple_sharding;
|
||||
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
|
||||
for (int64 parameter : *input_mapping) {
|
||||
const int core = (*arg_cores)[parameter];
|
||||
const int root_device = 0;
|
||||
auto it = arg_cores.find(parameter);
|
||||
const int core = it == arg_cores.end() ? 0 : it->second;
|
||||
*tuple_sharding.add_tuple_shardings() =
|
||||
core == -1 ? xla::sharding_builder::AssignDevice(root_device)
|
||||
: xla::sharding_builder::AssignDevice(core);
|
||||
xla::sharding_builder::AssignDevice(core);
|
||||
}
|
||||
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
|
||||
tuple_sharding);
|
||||
@ -628,7 +686,8 @@ Status XlaCompiler::BuildArguments(
|
||||
tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
|
||||
}
|
||||
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
|
||||
const int core = (*arg_cores)[input_mapping->at(i)];
|
||||
auto it = arg_cores.find(i);
|
||||
const int core = it == arg_cores.end() ? -1 : it->second;
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? absl::optional<xla::OpSharding>()
|
||||
: xla::sharding_builder::AssignDevice(core));
|
||||
@ -636,7 +695,8 @@ Status XlaCompiler::BuildArguments(
|
||||
}
|
||||
} else {
|
||||
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
|
||||
const int core = (*arg_cores)[input_mapping->at(i)];
|
||||
auto it = arg_cores.find(i);
|
||||
const int core = it == arg_cores.end() ? -1 : it->second;
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? absl::optional<xla::OpSharding>()
|
||||
: xla::sharding_builder::AssignDevice(core));
|
||||
@ -671,14 +731,14 @@ Status XlaCompiler::BuildArguments(
|
||||
// TODO(b/76097077): propagate device assignments onto arguments and
|
||||
// return values of functions, and then reshape unconditionally.
|
||||
if (is_entry_computation) {
|
||||
arg_expression.set_handle(
|
||||
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
|
||||
arg_expression = XlaExpression::XlaOp(
|
||||
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type);
|
||||
} else {
|
||||
arg_expression.set_handle(arg_handles[i]);
|
||||
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
|
||||
}
|
||||
break;
|
||||
case XlaCompiler::Argument::kToken: {
|
||||
arg_expression.set_handle(arg_handles[i]);
|
||||
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
|
||||
break;
|
||||
}
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
@ -710,7 +770,7 @@ Status XlaCompiler::CompileSingleOp(
|
||||
Node* node;
|
||||
string arg_name = absl::StrCat("_arg", i);
|
||||
Status status =
|
||||
NodeBuilder(arg_name, "_Arg")
|
||||
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
|
||||
.ControlInput(graph->source_node())
|
||||
.Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
|
||||
: args[i].type)
|
||||
@ -724,7 +784,7 @@ Status XlaCompiler::CompileSingleOp(
|
||||
for (int64 i = 0; i < result_types.size(); ++i) {
|
||||
Node* node;
|
||||
string retval_name = absl::StrCat("_retval", i);
|
||||
Status status = NodeBuilder(retval_name, "_Retval")
|
||||
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
|
||||
.Input(main_node, i)
|
||||
.Attr("T", result_types[i])
|
||||
.Attr("index", i)
|
||||
@ -788,6 +848,32 @@ Status ValidateGraph(const Graph* graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts the value of any expressions whose values are known at compile-time
|
||||
// to constants.
|
||||
Status ResolveConstantExpressionsToConstants(
|
||||
xla::Client* client, absl::Span<XlaExpression> expressions) {
|
||||
for (XlaExpression& expression : expressions) {
|
||||
if (expression.kind() == XlaExpression::Kind::kXlaOp) {
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
|
||||
expression.ResolveConstant(client));
|
||||
if (constant.has_value()) {
|
||||
expression = XlaExpression::Constant(*constant);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces, in place, every kConstant expression in `expressions` with a kXlaOp
// expression built on `builder` via XlaExpression::AsXlaOp, keeping the
// original dtype. Expressions of any other kind are left unchanged.
void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
                                   absl::Span<XlaExpression> expressions) {
  for (XlaExpression& expr : expressions) {
    if (expr.kind() != XlaExpression::Kind::kConstant) continue;

    // Materialize the op before overwriting `expr`, which still holds the
    // constant tensor the op is built from.
    const xla::XlaOp as_op = expr.AsXlaOp(builder);
    expr = XlaExpression::XlaOp(as_op, expr.dtype());
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
@ -815,10 +901,9 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
options_.device_type, name));
|
||||
|
||||
xla::XlaBuilder builder(name);
|
||||
XlaContext* context = new XlaContext(
|
||||
this, &builder, options_.allow_cpu_custom_calls,
|
||||
options.resolve_compile_time_constants, options.is_entry_computation,
|
||||
&options_.shape_representation_fn);
|
||||
XlaContext* context =
|
||||
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
|
||||
&options_.shape_representation_fn);
|
||||
core::ScopedUnref context_unref(context);
|
||||
|
||||
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
|
||||
@ -833,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
real_args.push_back(token_arg);
|
||||
}
|
||||
|
||||
std::map<int, int> arg_cores;
|
||||
std::map<int, int> retval_cores;
|
||||
TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
|
||||
ComputeArgAndRetvalCores(*graph));
|
||||
|
||||
std::vector<XlaExpression> arg_expressions;
|
||||
std::vector<int> arg_cores;
|
||||
TF_RETURN_IF_ERROR(BuildArguments(
|
||||
*graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
|
||||
*graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
|
||||
&arg_expressions, &result->input_mapping, &result->xla_input_shapes,
|
||||
options.is_entry_computation));
|
||||
context->set_args(std::move(arg_expressions));
|
||||
@ -884,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
int num_computation_outputs;
|
||||
result->computation = std::make_shared<xla::XlaComputation>();
|
||||
result->outputs.resize(context->retvals().size());
|
||||
std::vector<XlaExpression> retvals = context->retvals();
|
||||
if (options.resolve_compile_time_constants) {
|
||||
TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
|
||||
client(), absl::Span<XlaExpression>(retvals)));
|
||||
} else {
|
||||
ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(BuildComputation(
|
||||
real_args, arg_cores, context->retvals(), context->resources(),
|
||||
std::move(token_output), options.return_updated_values_for_all_resources,
|
||||
real_args, retvals, arg_cores, retval_cores, context->resources(),
|
||||
std::move(token_output),
|
||||
options.is_entry_computation ? options_.shape_representation_fn
|
||||
: ShapeRepresentationFn{},
|
||||
options.return_updated_values_for_all_resources,
|
||||
options.always_return_tuple, &builder, result->computation.get(),
|
||||
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
|
||||
&result->resource_updates));
|
||||
|
@ -21,8 +21,10 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -415,7 +417,8 @@ class XlaCompiler {
|
||||
Status BuildArguments(const Graph& graph,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
bool use_tuple_arg, xla::XlaBuilder* builder,
|
||||
XlaContext* context, std::vector<int>* arg_cores,
|
||||
XlaContext* context,
|
||||
const std::map<int, int>& arg_cores,
|
||||
std::vector<XlaExpression>* arg_expressions,
|
||||
std::vector<int>* input_mapping,
|
||||
std::vector<xla::Shape>* input_shapes,
|
||||
|
@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
|
||||
|
||||
XlaContext::XlaContext(
|
||||
XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
|
||||
bool is_entry_computation,
|
||||
bool allow_cpu_custom_calls,
|
||||
const std::function<xla::StatusOr<xla::Shape>(
|
||||
const TensorShape&, DataType)>* shape_representation_fn)
|
||||
: compiler_(compiler),
|
||||
builder_(builder),
|
||||
allow_cpu_custom_calls_(allow_cpu_custom_calls),
|
||||
resolve_compile_time_constants_(resolve_compile_time_constants),
|
||||
is_entry_computation_(is_entry_computation),
|
||||
shape_representation_fn_(shape_representation_fn) {}
|
||||
|
||||
// Returns a fixed, human-readable label identifying this context.
string XlaContext::DebugString() { return "XLA JIT context"; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
void XlaContext::AddRetval(int retval_index, DataType type,
|
||||
const TensorShape& shape, const xla::XlaOp& handle) {
|
||||
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
|
||||
// Add the return value to the list being built up.
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
void XlaContext::SetRetval(int index, const XlaExpression& expression) {
|
||||
if (retvals_.size() <= index) {
|
||||
retvals_.resize(index + 1);
|
||||
}
|
||||
XlaExpression e;
|
||||
e.set_handle(handle);
|
||||
retvals_[retval_index] = Retval{type, shape, e};
|
||||
retvals_[index] = expression;
|
||||
}
|
||||
|
||||
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::LiteralSlice& literal) {
|
||||
VLOG(1) << "Adding retval index " << retval_index
|
||||
<< " with non-data-dependent tensor to XLA computation";
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
Tensor value;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
|
||||
XlaExpression e;
|
||||
e.set_constant_value(value);
|
||||
retvals_[retval_index] = Retval{dtype, value.shape(), e};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
|
||||
VLOG(1) << "Adding retval index " << retval_index << " with resource "
|
||||
<< resource->name() << ":" << resource->shape().DebugString()
|
||||
<< " to XLA computation";
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
XlaExpression e;
|
||||
e.set_resource(resource);
|
||||
retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::XlaBuilder* XlaContext::builder() { return builder_; }
|
||||
|
||||
Status XlaContext::CreateResource(
|
||||
XlaResource::Kind kind, int arg_num, string name, DataType type,
|
||||
TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -46,8 +46,7 @@ class XlaContext : public ResourceBase {
|
||||
// Creates a new XlaContext. See the documentation on the class data fields
|
||||
// for descriptions of the arguments.
|
||||
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
|
||||
bool is_entry_computation,
|
||||
bool allow_cpu_custom_calls,
|
||||
const std::function<xla::StatusOr<xla::Shape>(
|
||||
const TensorShape&, DataType)>* shape_representation_fn);
|
||||
|
||||
@ -57,37 +56,19 @@ class XlaContext : public ResourceBase {
|
||||
XlaCompiler* compiler() const { return compiler_; }
|
||||
|
||||
// Returns the XlaBuilder that Ops use for compiling new expressions.
|
||||
xla::XlaBuilder* builder();
|
||||
xla::XlaBuilder* builder() { return builder_; }
|
||||
|
||||
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
|
||||
|
||||
bool resolve_compile_time_constants() const {
|
||||
return resolve_compile_time_constants_;
|
||||
}
|
||||
bool is_entry_computation() const { return is_entry_computation_; }
|
||||
|
||||
const std::vector<XlaExpression>& args() const { return args_; }
|
||||
void set_args(std::vector<XlaExpression> args);
|
||||
|
||||
struct Retval {
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
// An XlaExpression representing the Retval's value.
|
||||
XlaExpression expression;
|
||||
};
|
||||
const std::vector<Retval>& retvals() { return retvals_; }
|
||||
const std::vector<XlaExpression>& retvals() { return retvals_; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
void AddRetval(int retval_index, DataType type, const TensorShape& shape,
|
||||
const xla::XlaOp& handle);
|
||||
|
||||
// As for Retval, but for return values that are compile-time constants.
|
||||
Status AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::LiteralSlice& literal);
|
||||
|
||||
// As for Retval, but for return values that are resource handles.
|
||||
Status AddResourceRetval(int retval_index, XlaResource* resource);
|
||||
// Sets a return value.
|
||||
// Since we do not always know in advance how many return values there are,
|
||||
// grows the return values vector to size index+1 if it is smaller.
|
||||
void SetRetval(int index, const XlaExpression& expression);
|
||||
|
||||
// Creates a resource with resource `kind` and initial value `handle`. `name`
|
||||
// is a descriptive name for use in error messages. See the `XlaResource`
|
||||
@ -140,24 +121,16 @@ class XlaContext : public ResourceBase {
|
||||
// Allow ops to emit CustomCall operations for CPU.
|
||||
const bool allow_cpu_custom_calls_;
|
||||
|
||||
// If true, constant return values are returned as Tensors instead of
|
||||
// run-time computation outputs.
|
||||
const bool resolve_compile_time_constants_;
|
||||
|
||||
// Arguments to the Tensorflow graph, indexed by _Arg index.
|
||||
// Includes both compile-time constant arguments and runtime parameters.
|
||||
std::vector<XlaExpression> args_;
|
||||
|
||||
// Return values of the Tensorflow graph, indexed by _Retval index.
|
||||
std::vector<Retval> retvals_;
|
||||
std::vector<XlaExpression> retvals_;
|
||||
|
||||
// Holds ownership of resources. The resources are not ordered.
|
||||
std::vector<std::unique_ptr<XlaResource>> resources_;
|
||||
|
||||
// Is this a top-level computation, or an inner computation (e.g., a while
|
||||
// body)?
|
||||
const bool is_entry_computation_;
|
||||
|
||||
// Describes the on-host shapes of parameters and return values. Also see:
|
||||
// XlaDevice::Options::shape_representation_fn.
|
||||
const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*
|
||||
|
145
tensorflow/compiler/tf2xla/xla_expression.cc
Normal file
145
tensorflow/compiler/tf2xla/xla_expression.cc
Normal file
@ -0,0 +1,145 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"

#include <numeric>
#include <utility>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
XlaExpression::XlaExpression() = default;
|
||||
|
||||
// Returns an expression whose kind is explicitly set to Kind::kInvalid.
XlaExpression XlaExpression::Invalid() {
  XlaExpression invalid;
  invalid.kind_ = Kind::kInvalid;
  return invalid;
}
|
||||
|
||||
// Returns a kConstant expression holding `value`. The expression's dtype is
// taken from the tensor itself.
XlaExpression XlaExpression::Constant(Tensor value) {
  XlaExpression e;
  e.kind_ = Kind::kConstant;
  // Read the dtype before `value` is moved from below.
  e.dtype_ = value.dtype();
  // `value` is a by-value sink parameter: move it into place rather than
  // making a second copy of the tensor handle and shape.
  e.constant_value_ = std::move(value);
  return e;
}
|
||||
|
||||
// Returns a kXlaOp expression backed by `value`, tagged with the TensorFlow
// type `dtype` supplied by the caller.
XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
  XlaExpression result;
  result.handle_ = value;
  result.dtype_ = dtype;
  result.kind_ = Kind::kXlaOp;
  return result;
}
|
||||
|
||||
// Returns a kResource expression that refers to `resource`. The pointer is
// stored as-is; the expression's dtype is always DT_RESOURCE.
XlaExpression XlaExpression::Resource(XlaResource* resource) {
  XlaExpression result;
  result.resource_ = resource;
  result.dtype_ = DT_RESOURCE;
  result.kind_ = Kind::kResource;
  return result;
}
|
||||
|
||||
// Returns a short lowercase name for the expression's kind, for use in error
// messages.
string XlaExpression::HumanString() const {
  // No `default:` label, so the compiler can still warn when a new Kind
  // enumerator is added without a case here.
  switch (kind_) {
    case Kind::kInvalid:
      return "invalid";
    case Kind::kConstant:
      return "constant";
    case Kind::kXlaOp:
      return "xla_op";
    case Kind::kResource:
      return "resource";
  }
  // Only reachable if kind_ holds a value outside the enumerators. Returning
  // here avoids flowing off the end of a value-returning function, which is
  // undefined behavior.
  return "unknown";
}
|
||||
|
||||
// Returns this expression as an xla::XlaOp on `builder`.
//  * kConstant: emits the constant tensor as a literal constant on `builder`.
//  * kXlaOp: returns the stored handle, provided it belongs to `builder`.
//  * anything else: an InvalidArgument error.
// Errors are routed through builder->ReportErrorOrReturn rather than being
// returned as a Status.
xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
  const auto make_op = [&]() -> xla::StatusOr<xla::XlaOp> {
    if (kind_ == Kind::kConstant) {
      xla::BorrowingLiteral borrowed;
      TF_RETURN_IF_ERROR(
          HostTensorToBorrowingLiteral(constant_value_, &borrowed));
      return xla::ConstantLiteral(builder, borrowed);
    }
    if (kind_ == Kind::kXlaOp) {
      // A handle cannot be used with a builder other than its own.
      if (handle_.builder() != builder) {
        return errors::InvalidArgument(
            "Mismatched builders in XlaExpression::AsXlaOp");
      }
      return handle_;
    }
    return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
                                   HumanString());
  };
  return builder->ReportErrorOrReturn(make_op);
}
|
||||
|
||||
// Tries to evaluate this expression to a host tensor at compile time.
//  * kConstant: returns the stored tensor.
//  * kXlaOp: if the owning builder reports the op as constant, builds the
//    constant subgraph, evaluates it through `client`, and returns the result
//    converted to a Tensor of this expression's dtype; otherwise returns an
//    empty optional.
//  * kResource / kInvalid: InvalidArgument error.
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
    xla::Client* client) const {
  if (kind() == Kind::kConstant) {
    return {constant_value()};
  }
  if (kind() != Kind::kXlaOp) {
    return errors::InvalidArgument(
        "ResolveConstant called on XlaExpression: ", HumanString());
  }

  xla::XlaBuilder* const op_builder = handle().builder();
  TF_ASSIGN_OR_RETURN(bool is_constant, op_builder->IsConstant(handle()));
  if (!is_constant) return {absl::nullopt};

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      op_builder->BuildConstantSubGraph(handle()));

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

  // The XLA layout is specified minor to major, whereas TensorFlow uses a
  // major to minor order, so fill the indices in reverse: {dims-1, ..., 0}.
  std::vector<int64> minor_to_major(shape.dims());
  std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  xla::Layout layout = xla::LayoutUtil::MakeLayout(minor_to_major);

  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      client->ComputeConstant(constant_graph, &layout));
  Tensor tensor;
  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
  return {tensor};
}
|
||||
|
||||
// Returns the TensorFlow shape of the expression.
//  * kConstant: the shape of the stored tensor.
//  * kXlaOp: the builder-reported XLA shape, converted to a TensorShape.
//  * kResource: a scalar shape (the shape of a resource handle, not of the
//    resource's value).
//  * kInvalid: InvalidArgument error.
xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
  // No `default:` label, so the compiler can still warn when a new Kind
  // enumerator is added without a case here.
  switch (kind_) {
    case Kind::kConstant:
      return constant_value().shape();
    case Kind::kXlaOp: {
      TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
                          handle().builder()->GetShape(handle()));
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
      return shape;
    }
    case Kind::kResource:
      return TensorShape({});
    case Kind::kInvalid:
      return errors::InvalidArgument(
          "GetShape() called on invalid XlaExpression");
  }
  // Only reachable if kind_ holds a value outside the enumerators. Returning
  // an error avoids flowing off the end of a value-returning function, which
  // is undefined behavior.
  return errors::InvalidArgument(
      "GetShape() called on XlaExpression of unknown kind");
}
|
||||
|
||||
} // namespace tensorflow
|
115
tensorflow/compiler/tf2xla/xla_expression.h
Normal file
115
tensorflow/compiler/tf2xla/xla_expression.h
Normal file
@ -0,0 +1,115 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// An XlaExpression represents a symbolic TensorFlow value in a TF->XLA
|
||||
// compilation.
|
||||
// An expression is one of:
|
||||
// * a constant tensor.
|
||||
// * an xla::XlaOp, representing a symbolic XLA value.
|
||||
// * a resource, e.g., a variable, represented as an XlaResource pointer.
|
||||
//
|
||||
// Constant tensors are mostly an optimization to avoid passing large constants
|
||||
// to XLA, but are also sometimes used to represent tensors that have no XLA
|
||||
// representation, for example, DT_STRING tensors. A canonical use case might be
|
||||
// an error message string.
|
||||
class XlaExpression {
|
||||
public:
|
||||
enum class Kind {
|
||||
kInvalid,
|
||||
kConstant,
|
||||
kXlaOp,
|
||||
kResource,
|
||||
};
|
||||
|
||||
XlaExpression();
|
||||
XlaExpression(const XlaExpression&) = default;
|
||||
XlaExpression& operator=(const XlaExpression&) = default;
|
||||
|
||||
// Builds an invalid expression. (Same as the default constructor, but makes
|
||||
// the intent clearer.)
|
||||
static XlaExpression Invalid();
|
||||
|
||||
// Builds a constant XLA expression.
|
||||
static XlaExpression Constant(Tensor value);
|
||||
|
||||
// Builds an XlaOp expression. Since the mapping from TF data types to XLA
|
||||
// types is not 1-1, the TF type must also be provided; in general it cannot
|
||||
// be derived from the XLA type.
|
||||
static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
|
||||
|
||||
// Builds a resource expression.
|
||||
static XlaExpression Resource(XlaResource* resource);
|
||||
|
||||
Kind kind() const { return kind_; }
|
||||
|
||||
DataType dtype() const { return dtype_; }
|
||||
|
||||
// handle() returns the XlaOp that backs a kXlaOp expression.
|
||||
const xla::XlaOp& handle() const { return handle_; }
|
||||
|
||||
const Tensor& constant_value() const { return constant_value_; }
|
||||
|
||||
XlaResource* resource() const { return resource_; }
|
||||
|
||||
// Returns a human-readable summary of the expression.
|
||||
string HumanString() const;
|
||||
|
||||
// Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
|
||||
// an erroneous XlaOp if the expression is not a constant or an XlaOp.
|
||||
xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
|
||||
|
||||
// If a kXlaOp or kConstant expression can be resolved to a compile-time
|
||||
// constant, returns the value as a host-memory Tensor. Returns an empty
|
||||
// optional if it cannot be resolved. Returns an error if passed a resource
|
||||
// expression.
|
||||
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
|
||||
xla::Client* client) const;
|
||||
|
||||
// Returns the shape of the tensor.
|
||||
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
|
||||
// not the shape of the resource's value.
|
||||
xla::StatusOr<TensorShape> GetShape() const;
|
||||
|
||||
private:
|
||||
Kind kind_ = Kind::kInvalid;
|
||||
|
||||
DataType dtype_ = DT_INVALID;
|
||||
|
||||
// The XLA handle of the expression's computation, if kind_ == kXlaOp.
|
||||
xla::XlaOp handle_;
|
||||
|
||||
// The value of the constant, if kind_ == kConstant.
|
||||
Tensor constant_value_;
|
||||
|
||||
// The resource, if kind_ == kResource. Not owned.
|
||||
XlaResource* resource_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
|
135
tensorflow/compiler/tf2xla/xla_expression_test.cc
Normal file
135
tensorflow/compiler/tf2xla/xla_expression_test.cc
Normal file
@ -0,0 +1,135 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaExpressionTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
client_ = xla::ClientLibrary::LocalClientOrDie();
|
||||
builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
|
||||
constant_ = test::AsScalar<int32>(42);
|
||||
op_ = xla::ConstantR0<int32>(builder_.get(), 7);
|
||||
non_constant_op_ = xla::Parameter(
|
||||
builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
|
||||
resource_ = absl::make_unique<XlaResource>(
|
||||
XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
|
||||
DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
|
||||
/*tensor_array_gradients=*/std::set<string>(),
|
||||
/*tensor_array_multiple_writes_aggregate=*/false);
|
||||
}
|
||||
|
||||
xla::Client* client_;
|
||||
std::unique_ptr<xla::XlaBuilder> builder_;
|
||||
Tensor constant_;
|
||||
xla::XlaOp op_;
|
||||
xla::XlaOp non_constant_op_;
|
||||
std::unique_ptr<XlaResource> resource_;
|
||||
};
|
||||
|
||||
// Checks that a default-constructed expression and each factory function
// report the corresponding Kind.
TEST_F(XlaExpressionTest, Kind) {
  XlaExpression default_constructed;
  XlaExpression invalid = XlaExpression::Invalid();
  XlaExpression constant = XlaExpression::Constant(constant_);
  XlaExpression xla_op = XlaExpression::XlaOp(op_, DT_INT32);
  XlaExpression resource = XlaExpression::Resource(resource_.get());

  // Kind values are compared with == inside EXPECT_TRUE rather than through
  // EXPECT_EQ, as in the rest of this file.
  EXPECT_TRUE(default_constructed.kind() == XlaExpression::Kind::kInvalid);
  EXPECT_TRUE(invalid.kind() == XlaExpression::Kind::kInvalid);
  EXPECT_TRUE(constant.kind() == XlaExpression::Kind::kConstant);
  EXPECT_TRUE(xla_op.kind() == XlaExpression::Kind::kXlaOp);
  EXPECT_TRUE(resource.kind() == XlaExpression::Kind::kResource);
}
|
||||
|
||||
// Checks the human-readable name reported for every kind of expression.
TEST_F(XlaExpressionTest, HumanString) {
  XlaExpression default_constructed;
  EXPECT_EQ("invalid", default_constructed.HumanString());

  XlaExpression invalid = XlaExpression::Invalid();
  EXPECT_EQ("invalid", invalid.HumanString());

  XlaExpression constant = XlaExpression::Constant(constant_);
  EXPECT_EQ("constant", constant.HumanString());

  XlaExpression xla_op = XlaExpression::XlaOp(op_, DT_INT32);
  EXPECT_EQ("xla_op", xla_op.HumanString());

  XlaExpression resource = XlaExpression::Resource(resource_.get());
  EXPECT_EQ("resource", resource.HumanString());
}
|
||||
|
||||
// Checks AsXlaOp() for the two value-carrying kinds: an op expression must
// hand back the identical op, and a constant expression must yield an op
// that evaluates to the constant's value.
TEST_F(XlaExpressionTest, AsXlaOp) {
  xla::XlaBuilder* const b = builder_.get();

  // Op-backed expression: the wrapped handle comes back unchanged.
  XlaExpression op_expression = XlaExpression::XlaOp(op_, DT_INT32);
  xla::XlaOp round_tripped = op_expression.AsXlaOp(b);
  EXPECT_TRUE(op_.IsIdenticalTo(round_tripped));

  // Constant-backed expression: evaluate the resulting op and compare the
  // result with the scalar stored in constant_ (42).
  XlaExpression constant_expression = XlaExpression::Constant(constant_);
  xla::XlaOp materialized = constant_expression.AsXlaOp(b);
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation constant_graph,
                          builder_->BuildConstantSubGraph(materialized));
  TF_ASSERT_OK_AND_ASSIGN(xla::Literal evaluated,
                          client_->ComputeConstant(constant_graph));
  xla::Literal expected = xla::LiteralUtil::CreateR0<int32>(42);
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, evaluated));
}
|
||||
|
||||
// Checks GetShape(): invalid expressions report an error, while the resource,
// op and constant samples all report a scalar shape.
TEST_F(XlaExpressionTest, GetShape) {
  XlaExpression default_constructed;
  EXPECT_FALSE(default_constructed.GetShape().ok());
  XlaExpression invalid = XlaExpression::Invalid();
  EXPECT_FALSE(invalid.GetShape().ok());

  const TensorShape scalar({});

  // A resource reports a scalar shape even though the fixture's resource was
  // constructed with shape {17, 3}.
  XlaExpression resource = XlaExpression::Resource(resource_.get());
  TF_ASSERT_OK_AND_ASSIGN(TensorShape shape_of_resource, resource.GetShape());
  EXPECT_EQ(scalar, shape_of_resource);

  XlaExpression xla_op = XlaExpression::XlaOp(op_, DT_INT32);
  TF_ASSERT_OK_AND_ASSIGN(TensorShape shape_of_op, xla_op.GetShape());
  EXPECT_EQ(scalar, shape_of_op);

  XlaExpression constant = XlaExpression::Constant(constant_);
  TF_ASSERT_OK_AND_ASSIGN(TensorShape shape_of_constant, constant.GetShape());
  EXPECT_EQ(scalar, shape_of_constant);
}
|
||||
|
||||
// Checks ResolveConstant(): invalid and resource expressions are errors, a
// constant-valued op and a host constant resolve to their values, and a
// parameter op resolves to an empty optional.
TEST_F(XlaExpressionTest, ResolveConstant) {
  // Kinds that cannot be resolved at all produce a non-OK status.
  XlaExpression default_constructed;
  EXPECT_FALSE(default_constructed.ResolveConstant(client_).ok());
  XlaExpression invalid = XlaExpression::Invalid();
  EXPECT_FALSE(invalid.ResolveConstant(client_).ok());
  XlaExpression resource = XlaExpression::Resource(resource_.get());
  EXPECT_FALSE(resource.ResolveConstant(client_).ok());

  // An op with a constant value resolves to that value (7).
  XlaExpression constant_op = XlaExpression::XlaOp(op_, DT_INT32);
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> resolved_op,
                          constant_op.ResolveConstant(client_));
  ASSERT_TRUE(resolved_op.has_value());
  test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *resolved_op);

  // A parameter op succeeds but yields no value.
  XlaExpression parameter_op =
      XlaExpression::XlaOp(non_constant_op_, DT_FLOAT);
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> resolved_parameter,
                          parameter_op.ResolveConstant(client_));
  EXPECT_FALSE(resolved_parameter.has_value());

  // A host constant resolves to itself.
  XlaExpression host_constant = XlaExpression::Constant(constant_);
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> resolved_constant,
                          host_constant.ResolveConstant(client_));
  ASSERT_TRUE(resolved_constant.has_value());
  test::ExpectTensorEqual<int32>(constant_, *resolved_constant);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
|
||||
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK(expression->handle().valid() || expression->resource() != nullptr);
|
||||
VLOG(1) << "Fetched T" << expression->handle();
|
||||
CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
|
||||
<< expression->HumanString();
|
||||
return expression;
|
||||
}
|
||||
|
||||
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
|
||||
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
|
||||
// Assigns an XlaExpression to a tensor on an XLA compilation device.
|
||||
static void AssignExpressionToTensor(Tensor* tensor,
|
||||
const XlaExpression& value) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
|
||||
CHECK(!expression->handle().valid());
|
||||
return const_cast<XlaExpression*>(expression);
|
||||
CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
|
||||
<< expression->HumanString();
|
||||
*const_cast<XlaExpression*>(expression) = value;
|
||||
}
|
||||
|
||||
// Retrieves the XlaOp from an input Tensor to an Op. This computation was
|
||||
// constructed by an Op that executed previously and created the output Tensor
|
||||
// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
|
||||
static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
|
||||
return CastExpressionFromTensor(tensor)->handle();
|
||||
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
|
||||
return *CastExpressionFromTensor(context_->input(index));
|
||||
}
|
||||
|
||||
const xla::XlaOp& XlaOpKernelContext::Input(int index) {
|
||||
return GetComputationFromTensor(context_->input(index));
|
||||
const XlaExpression& XlaOpKernelContext::InputExpression(
|
||||
absl::string_view name) {
|
||||
return *CastExpressionFromTensor(GetInputTensorByName(name));
|
||||
}
|
||||
|
||||
const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
|
||||
return GetComputationFromTensor(GetInputTensorByName(name));
|
||||
xla::XlaOp XlaOpKernelContext::Input(int index) {
|
||||
return InputExpression(index).AsXlaOp(builder());
|
||||
}
|
||||
|
||||
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
|
||||
return InputExpression(name).AsXlaOp(builder());
|
||||
}
|
||||
|
||||
TensorShape XlaOpKernelContext::InputShape(int index) {
|
||||
@ -125,59 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name,
|
||||
Status XlaOpKernelContext::ConstantInputReshaped(
|
||||
int index, absl::Span<const int64> new_dims,
|
||||
xla::Literal* constant_literal) {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
TensorShape new_shape(new_dims);
|
||||
if (tensor.NumElements() != new_shape.num_elements()) {
|
||||
return errors::InvalidArgument(
|
||||
context_->op_kernel().name(), " input ", index, " has shape ",
|
||||
tensor.shape().DebugString(),
|
||||
" but was asked to be reshaped to incompatible shape ",
|
||||
new_shape.DebugString());
|
||||
}
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
|
||||
// If the tensor has a known constant value, there is no need to invoke XLA.
|
||||
if (expression->has_constant_value()) {
|
||||
Tensor temp(tensor.dtype());
|
||||
if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
|
||||
// This should never happen. The constant should have a shape compatible
|
||||
// with the enclosing Tensor.
|
||||
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Make sure we treat zero-element tensors as constant.
|
||||
if (new_shape.num_elements() == 0) {
|
||||
Tensor temp(tensor.dtype(), new_shape);
|
||||
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::XlaOp handle = expression->handle();
|
||||
if (new_shape != tensor.shape()) {
|
||||
// Reshape the handle to the desired shape.
|
||||
handle = xla::Reshape(handle, new_shape.dim_sizes());
|
||||
}
|
||||
|
||||
// The XLA layout is specified minor to major, and TensorFlow's minor
|
||||
// dimension is the last one.
|
||||
std::vector<int64> layout_indices(new_shape.dims());
|
||||
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
|
||||
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
|
||||
|
||||
xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
|
||||
if (!is_constant.ok()) {
|
||||
Status status = is_constant.status();
|
||||
XlaExpression e = InputExpression(index);
|
||||
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
|
||||
e.ResolveConstant(compiler()->client());
|
||||
if (!constant_or_status.ok()) {
|
||||
Status status = constant_or_status.status();
|
||||
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
|
||||
context_->op_kernel().type_string(),
|
||||
" operator as a compile-time constant.");
|
||||
return status;
|
||||
}
|
||||
|
||||
if (!is_constant.ValueOrDie()) {
|
||||
absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
|
||||
if (!constant.has_value()) {
|
||||
return errors::InvalidArgument(
|
||||
"Input ", index, " to ", context_->op_kernel().type_string(),
|
||||
" operator must be a compile-time constant.\n"
|
||||
@ -190,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped(
|
||||
"stateful operation such as a random number generator.");
|
||||
}
|
||||
|
||||
// Ask the XLA compiler to evaluate the data handle to a literal.
|
||||
xla::StatusOr<xla::XlaComputation> constant_graph =
|
||||
builder()->BuildConstantSubGraph(handle);
|
||||
if (!constant_graph.ok()) {
|
||||
return errors::Internal(
|
||||
"Error getting a compile-time constant graph for ",
|
||||
context_->op_kernel().name(), " input ", index,
|
||||
".\nError: ", constant_graph.status().error_message());
|
||||
Tensor temp(constant->dtype());
|
||||
if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
|
||||
return errors::InvalidArgument(
|
||||
context_->op_kernel().name(), " input ", index, " has shape ",
|
||||
constant->shape().DebugString(),
|
||||
" but was asked to be reshaped to incompatible shape ",
|
||||
TensorShape(new_dims).DebugString());
|
||||
}
|
||||
xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
|
||||
constant_graph.ValueOrDie(), &layout);
|
||||
if (!computed.ok()) {
|
||||
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
|
||||
" input ", index,
|
||||
" as a compile-time constant.\nError: ",
|
||||
computed.status().error_message());
|
||||
}
|
||||
*constant_literal = std::move(computed).ValueOrDie();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -363,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
|
||||
handles->clear();
|
||||
shapes->clear();
|
||||
for (const Tensor& input : inputs) {
|
||||
handles->push_back(GetComputationFromTensor(input));
|
||||
handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
|
||||
shapes->push_back(input.shape());
|
||||
}
|
||||
return Status::OK();
|
||||
@ -449,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
|
||||
Tensor** output) {
|
||||
// The step's default allocator is the dummy XlaCompilationAllocator which
|
||||
// simply allocates a metadata buffer to hold the expression to which it
|
||||
// corresponds.
|
||||
if (expected_output_dtype(index) == DT_VARIANT) {
|
||||
// tensor_data() is not supported for variant Tensor (i.e.,
|
||||
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
|
||||
// XlaExpression inside the Tensor's tensor_data() does not work for
|
||||
// variant. Instead construct a uint8 tensor and store the expression in its
|
||||
// value.
|
||||
// TODO(jpienaar): This should be refactored to stop masquerading
|
||||
// XlaExpressions as Tensors.
|
||||
*output = new Tensor();
|
||||
TensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context_->allocate_temp(DT_UINT8, tensor_shape, *output));
|
||||
context_->set_output(index, **output);
|
||||
} else {
|
||||
TensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
|
||||
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
|
||||
void XlaOpKernelContext::SetOutputExpression(int index,
|
||||
const XlaExpression& expression) {
|
||||
Status status = [&] {
|
||||
// The step's default allocator is the dummy XlaCompilationAllocator which
|
||||
// simply allocates a metadata buffer to hold the expression to which it
|
||||
// corresponds.
|
||||
Tensor* output = nullptr;
|
||||
// Provides a special behavior for DT_VARIANT: a variant is treated as
|
||||
// DT_UINT8 scalar as the type to allow mapping for variant to more generic
|
||||
// types.
|
||||
if (expression.dtype() == DT_VARIANT) {
|
||||
// tensor_data() is not supported for variant Tensor (i.e.,
|
||||
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
|
||||
// XlaExpression inside the Tensor's tensor_data() does not work for
|
||||
// variant. Instead construct a uint8 tensor and store the expression in
|
||||
// its value.
|
||||
// TODO(jpienaar): This should be refactored to stop masquerading
|
||||
// XlaExpressions as Tensors.
|
||||
output = new Tensor();
|
||||
TensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context_->allocate_temp(DT_UINT8, tensor_shape, output));
|
||||
context_->set_output(index, *output);
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
|
||||
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
|
||||
}
|
||||
AssignExpressionToTensor(output, expression);
|
||||
return Status::OK();
|
||||
}();
|
||||
if (!status.ok()) {
|
||||
SetStatus(status);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
|
||||
// Makes the host Tensor that will refer to the expression.
|
||||
Tensor* output = nullptr;
|
||||
auto shape_or = builder()->GetShape(handle);
|
||||
if (!shape_or.ok()) {
|
||||
SetStatus(shape_or.status());
|
||||
return;
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context_,
|
||||
allocate_output(index, shape_or.ValueOrDie(), &output));
|
||||
|
||||
// The expression is stored in the tensor's data buffer. Fill in the
|
||||
// fields now.
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_handle(handle);
|
||||
SetOutputExpression(
|
||||
index,
|
||||
XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
|
||||
const TensorShape& shape = constant.shape();
|
||||
|
||||
xla::BorrowingLiteral literal;
|
||||
OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
|
||||
|
||||
xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
|
||||
CHECK(handle.valid());
|
||||
|
||||
// Make the Tensor that will refer to the expression.
|
||||
Tensor* output = nullptr;
|
||||
// The step's default allocator is the dummy XlaCompilationAllocator which
|
||||
// simply allocates a metadata buffer to hold the expression to which it
|
||||
// corresponds.
|
||||
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
|
||||
|
||||
// The expression is stored in the tensor's data buffer. Fill in the
|
||||
// fields now.
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_handle(handle);
|
||||
expression->set_constant_value(constant);
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetInvalidOutput(int index) {
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context_,
|
||||
context_->allocate_output(index, TensorShape({}), &output));
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
xla::XlaOp handle;
|
||||
expression->set_handle(handle);
|
||||
SetOutputExpression(index, XlaExpression::Constant(constant));
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
|
||||
Tensor* output = nullptr;
|
||||
// The shape of the output tensor is the shape of the resource itself
|
||||
// (i.e., a scalar), not the shape of the resource's value.
|
||||
OP_REQUIRES_OK(context_,
|
||||
context_->allocate_output(index, TensorShape(), &output));
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_resource(resource);
|
||||
SetOutputExpression(index, XlaExpression::Resource(resource));
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
|
||||
|
@ -88,9 +88,9 @@ class XlaOpKernelContext {
|
||||
// Returns input `index` as a XlaOp. Unlike
|
||||
// OpKernelContext::Input returns a symbolic value rather than a concrete
|
||||
// Tensor.
|
||||
const xla::XlaOp& Input(int index);
|
||||
xla::XlaOp Input(int index);
|
||||
// Returns input `name` as a XlaOp.
|
||||
const xla::XlaOp& Input(absl::string_view name);
|
||||
xla::XlaOp Input(absl::string_view name);
|
||||
|
||||
// Returns true if all inputs are the same shape, otherwise sets the
|
||||
// status to a non-OK value and returns false.
|
||||
@ -142,6 +142,10 @@ class XlaOpKernelContext {
|
||||
Status ConstantInputList(absl::string_view name,
|
||||
std::vector<xla::Literal>* literals);
|
||||
|
||||
// Returns an XlaExpression describing the value of the input identified by
// 'index' or 'name'.
|
||||
const XlaExpression& InputExpression(int index);
|
||||
const XlaExpression& InputExpression(absl::string_view name);
|
||||
|
||||
// Outputs
|
||||
|
||||
int num_outputs() const { return context_->num_outputs(); }
|
||||
@ -159,9 +163,8 @@ class XlaOpKernelContext {
|
||||
// SetConstantOutput where possible.
|
||||
void SetConstantOutput(int index, const Tensor& host_tensor);
|
||||
|
||||
// Sets output `index` to an invalid value.
|
||||
// Any subsequent attempt to consume this output will cause an error.
|
||||
void SetInvalidOutput(int index);
|
||||
// Sets output `index` to the XlaExpression `expression`.
|
||||
void SetOutputExpression(int index, const XlaExpression& expression);
|
||||
|
||||
// Status handling.
|
||||
void SetStatus(const Status& status) { context_->SetStatus(status); }
|
||||
@ -249,11 +252,6 @@ class XlaOpKernelContext {
|
||||
// Returns the tensor of input `name`.
|
||||
const Tensor& GetInputTensorByName(absl::string_view name);
|
||||
|
||||
// Wraps OpKernelContext's allocate_output method while providing special
|
||||
// behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
|
||||
// type to allow mapping for variant to more generic types.
|
||||
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
|
||||
|
||||
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
|
||||
// InputShape(index), and stores it in `*constant_literal`. If the input
|
||||
// cannot be evaluated, e.g., because it depends on unbound parameters,
|
||||
|
Loading…
Reference in New Issue
Block a user