[TF:XLA] Cleanups to handling of arguments during XLA compilation:
* combine resource kinds in XlaCompiler::Argument::Kind, use a separate XlaResource::Kind field to distinguish different kinds of resource. * merge XlaContext::HandleOrConstant and XlaExpression, which were almost identical. * remove XlaContext::Argument; instead, build XlaExpressions directly from XlaCompiler and add them to the XlaContext. PiperOrigin-RevId: 168439341
This commit is contained in:
parent
7f5346a809
commit
8f37f30027
tensorflow/compiler
@ -184,7 +184,8 @@ Status BuildArguments(int num_constant_args,
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
|
||||
arg.name = variable_args[variable_id].name;
|
||||
arg.kind = XlaCompiler::Argument::kVariable;
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = XlaResource::kVariable;
|
||||
if (variable_args[variable_id].present) {
|
||||
const Tensor& value = variable_args[variable_id].value;
|
||||
arg.type = value.dtype();
|
||||
|
@ -21,14 +21,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// This OpKernel implements the _Arg Op for XLA JIT devices. It
|
||||
// associates its output with one of the arguments to a
|
||||
// subcomputation.
|
||||
class ArgOp : public XlaOpKernel {
|
||||
class XlaArgOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
explicit XlaArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
|
||||
}
|
||||
@ -49,35 +48,13 @@ class ArgOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
XlaContext& xc = XlaContext::Get(ctx);
|
||||
const XlaContext::Argument& arg = xc.args()[index_];
|
||||
if (arg.is_resource) {
|
||||
XlaResource::Kind kind;
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kVariable:
|
||||
kind = XlaResource::kVariable;
|
||||
break;
|
||||
case XlaCompiler::Argument::kTensorArray:
|
||||
kind = XlaResource::kTensorArray;
|
||||
break;
|
||||
case XlaCompiler::Argument::kStack:
|
||||
kind = XlaResource::kStack;
|
||||
break;
|
||||
default:
|
||||
CHECK(false);
|
||||
}
|
||||
|
||||
// TODO(phawkins): this code assumes that variables do not alias.
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
xc.CreateResource(kind, index_, arg.name, arg.value.type,
|
||||
arg.value.handle, &resource));
|
||||
resource->tensor_array_size = arg.tensor_array_size;
|
||||
ctx->SetResourceOutput(0, resource);
|
||||
} else if (arg.value.is_constant) {
|
||||
ctx->SetConstantOutput(0, arg.value.constant_value);
|
||||
const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
|
||||
if (arg.resource() != nullptr) {
|
||||
ctx->SetResourceOutput(0, arg.resource());
|
||||
} else if (arg.has_constant_value()) {
|
||||
ctx->SetConstantOutput(0, arg.constant_value());
|
||||
} else {
|
||||
ctx->SetOutput(0, arg.value.handle);
|
||||
ctx->SetOutput(0, arg.handle());
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,10 +62,9 @@ class ArgOp : public XlaOpKernel {
|
||||
int index_;
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), ArgOp);
|
||||
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -50,19 +50,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
|
||||
TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
|
||||
|
||||
arg.initialized = resource->value.handle() > 0;
|
||||
switch (resource->kind) {
|
||||
case XlaResource::kVariable:
|
||||
arg.kind = XlaCompiler::Argument::kVariable;
|
||||
break;
|
||||
case XlaResource::kTensorArray:
|
||||
arg.kind = XlaCompiler::Argument::kTensorArray;
|
||||
break;
|
||||
case XlaResource::kStack:
|
||||
arg.kind = XlaCompiler::Argument::kStack;
|
||||
break;
|
||||
case XlaResource::kInvalid:
|
||||
CHECK(false);
|
||||
}
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = resource->kind;
|
||||
arg.type = resource->type;
|
||||
if (arg.initialized) {
|
||||
auto shape = ctx->builder()->GetShape(resource->value);
|
||||
|
@ -139,8 +139,6 @@ class XlaExpression {
|
||||
Tensor constant_value_;
|
||||
|
||||
XlaResource* resource_ = nullptr; // Not owned.
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -104,9 +105,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
|
||||
XlaCompiler::~XlaCompiler() = default;
|
||||
|
||||
int64 XlaCompiler::NextStepId() {
|
||||
return next_step_id_++;
|
||||
}
|
||||
int64 XlaCompiler::NextStepId() { return next_step_id_++; }
|
||||
|
||||
uint64 XlaCompiler::SignatureHash::operator()(
|
||||
const std::pair<string, std::vector<Argument>>& signature) const {
|
||||
@ -260,10 +259,11 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
// `args` are the arguments to the computation.
|
||||
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
bool use_tuple_arg, xla::ComputationBuilder* builder,
|
||||
std::vector<XlaContext::Argument>* context_args,
|
||||
XlaContext* context,
|
||||
std::vector<XlaExpression>* arg_expressions,
|
||||
std::vector<int>* input_mapping,
|
||||
std::vector<xla::Shape>* input_shapes) {
|
||||
context_args->resize(args.size());
|
||||
arg_expressions->resize(args.size());
|
||||
|
||||
// Argument numbers of arguments and resources that are to be passed to the
|
||||
// XLA computation as runtime parameters.
|
||||
@ -271,33 +271,31 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
parameters.reserve(args.size());
|
||||
resources.reserve(args.size());
|
||||
|
||||
// Fills in constant arguments, and computes non-constant argument order.
|
||||
for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
|
||||
++i) {
|
||||
XlaContext::Argument& context_arg = (*context_args)[i];
|
||||
context_arg.kind = args[i].kind;
|
||||
context_arg.name = args[i].name;
|
||||
context_arg.value.constant_value = args[i].constant_value;
|
||||
context_arg.value.type = args[i].type;
|
||||
|
||||
switch (args[i].kind) {
|
||||
case XlaCompiler::Argument::kVariable:
|
||||
case XlaCompiler::Argument::kTensorArray:
|
||||
case XlaCompiler::Argument::kStack:
|
||||
context_arg.is_resource = true;
|
||||
if (args[i].initialized) {
|
||||
const XlaCompiler::Argument& arg = args[i];
|
||||
XlaExpression& arg_expression = (*arg_expressions)[i];
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kResource:
|
||||
TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
|
||||
// TODO(phawkins): this code assumes that resource arguments do not
|
||||
// alias.
|
||||
XlaResource* resource;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
|
||||
xla::ComputationDataHandle(), &resource));
|
||||
resource->tensor_array_size = arg.tensor_array_size;
|
||||
arg_expression.set_resource(resource);
|
||||
if (arg.initialized) {
|
||||
resources.push_back(i);
|
||||
context_arg.value.is_constant = false;
|
||||
} else {
|
||||
context_arg.value.is_constant = true;
|
||||
}
|
||||
context_arg.tensor_array_size = args[i].tensor_array_size;
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
parameters.push_back(i);
|
||||
context_arg.value.is_constant = false;
|
||||
break;
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
context_arg.value.is_constant = true;
|
||||
arg_expression.set_constant_value(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
return errors::Internal("Unreachable case in BuildArguments()");
|
||||
@ -313,27 +311,48 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
|
||||
input_shapes->resize(parameters.size());
|
||||
input_mapping->resize(parameters.size());
|
||||
for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[parameters[i]];
|
||||
// Computes the shapes of non-constant arguments.
|
||||
(*input_shapes)[i] = arg.shape;
|
||||
(*input_mapping)[i] = parameters[i];
|
||||
}
|
||||
|
||||
// Build parameter handles for non-constant arguments.
|
||||
std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
|
||||
if (use_tuple_arg) {
|
||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes);
|
||||
xla::ComputationDataHandle tuple =
|
||||
builder->Parameter(0, tuple_shape, "arg_tuple");
|
||||
for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
|
||||
(*context_args)[parameters[i]].value.handle =
|
||||
builder->GetTupleElement(tuple, i);
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
arg_handles[i] = builder->GetTupleElement(tuple, i);
|
||||
}
|
||||
} else {
|
||||
for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) {
|
||||
(*context_args)[parameters[i]].value.handle =
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
arg_handles[i] =
|
||||
builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
|
||||
}
|
||||
}
|
||||
|
||||
// Fill in the handles in non-constant arguments.
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[parameters[i]];
|
||||
XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kResource:
|
||||
TF_RET_CHECK(arg.initialized);
|
||||
arg_expression.resource()->value = arg_handles[i];
|
||||
arg_expression.resource()->initial_value = arg_handles[i];
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
arg_expression.set_handle(arg_handles[i]);
|
||||
break;
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
return errors::Internal("Unreachable case in BuildArguments()");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -354,7 +373,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
// index of a resource variable argument to the computation, and `type` is the
|
||||
// type of the final output.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaContext::HandleOrConstant>& retvals,
|
||||
const std::vector<XlaExpression>& retvals,
|
||||
const std::vector<std::unique_ptr<XlaResource>>& resources,
|
||||
bool has_side_effects, bool return_updated_values_for_all_resources,
|
||||
xla::ComputationBuilder* builder, xla::Computation* computation,
|
||||
@ -362,9 +381,9 @@ Status BuildComputation(
|
||||
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
|
||||
std::vector<xla::ComputationDataHandle> elems;
|
||||
elems.reserve(retvals.size());
|
||||
for (const XlaContext::HandleOrConstant& retval : retvals) {
|
||||
if (!retval.is_constant) {
|
||||
elems.push_back(retval.handle);
|
||||
for (const XlaExpression& retval : retvals) {
|
||||
if (!retval.has_constant_value()) {
|
||||
elems.push_back(retval.handle());
|
||||
}
|
||||
}
|
||||
*num_nonconst_outputs = elems.size();
|
||||
@ -460,11 +479,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
|
||||
result->tuple_arg = options.use_tuple_arg;
|
||||
|
||||
std::vector<XlaContext::Argument> context_args;
|
||||
TF_RETURN_IF_ERROR(BuildArguments(args, options.use_tuple_arg, &builder,
|
||||
&context_args, &result->input_mapping,
|
||||
&result->xla_input_shapes));
|
||||
context->set_args(std::move(context_args));
|
||||
std::vector<XlaExpression> arg_expressions;
|
||||
TF_RETURN_IF_ERROR(BuildArguments(
|
||||
args, options.use_tuple_arg, &builder, context, &arg_expressions,
|
||||
&result->input_mapping, &result->xla_input_shapes));
|
||||
context->set_args(std::move(arg_expressions));
|
||||
|
||||
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
|
||||
flib_runtime_, NextStepId()));
|
||||
@ -486,14 +505,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
VLOG(2) << "Outputs: total: " << context->retvals().size()
|
||||
<< " nonconstant: " << num_nonconst_outputs;
|
||||
result->outputs.resize(context->retvals().size());
|
||||
for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0;
|
||||
for (std::vector<XlaExpression>::size_type i = 0;
|
||||
i < context->retvals().size(); ++i) {
|
||||
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
|
||||
if (retval.is_constant) {
|
||||
const XlaExpression& retval = context->retvals()[i];
|
||||
if (retval.has_constant_value()) {
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.shape = retval.constant_value.shape();
|
||||
output.shape = retval.constant_value().shape();
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value;
|
||||
output.constant_value = retval.constant_value();
|
||||
}
|
||||
}
|
||||
|
||||
@ -518,10 +537,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
|
||||
// Converts the output shapes to TensorShapes.
|
||||
int computation_output = 0;
|
||||
for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0;
|
||||
for (std::vector<XlaExpression>::size_type i = 0;
|
||||
i < context->retvals().size(); ++i) {
|
||||
const XlaContext::HandleOrConstant& retval = context->retvals()[i];
|
||||
if (!retval.is_constant) {
|
||||
const XlaExpression& retval = context->retvals()[i];
|
||||
if (!retval.has_constant_value()) {
|
||||
CHECK_LT(computation_output, num_computation_outputs);
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.is_constant = false;
|
||||
|
@ -85,17 +85,9 @@ class XlaCompiler {
|
||||
// Argument is a compile-time constant. No associated runtime parameter.
|
||||
kConstant,
|
||||
|
||||
// Argument is a Variable resource. Has an associated runtime parameter
|
||||
// iff `initialized` is true.
|
||||
kVariable,
|
||||
|
||||
// Argument is a TensorArray resource. Has an associated runtime parameter
|
||||
// iff `initialized` is true.
|
||||
kTensorArray,
|
||||
|
||||
// Argument is a Stack resource. Has an associated runtime parameter
|
||||
// iff `initialized` is true.
|
||||
kStack,
|
||||
// Argument is a Variable, TensorArray, or Stack resource. Has an
|
||||
// associated runtime parameter iff `initialized` is true.
|
||||
kResource,
|
||||
|
||||
// Argument is a run-time parameter.
|
||||
kParameter,
|
||||
@ -118,11 +110,14 @@ class XlaCompiler {
|
||||
// The name of this argument, used for debugging.
|
||||
string name;
|
||||
|
||||
// For a kVariable or kTensorArray, has this resource been initialized?
|
||||
// For a kResource, what kind of resource is it?
|
||||
XlaResource::Kind resource_kind = XlaResource::kInvalid;
|
||||
|
||||
// For a kResource, has this resource been initialized?
|
||||
bool initialized = false;
|
||||
|
||||
// For a kTensorArray, what is the array's declared size? (Used for lazy
|
||||
// initialization.)
|
||||
// For a TensorArray or Stack resource, what is the array's declared size?
|
||||
// (Used for lazy initialization.)
|
||||
int64 tensor_array_size = -1;
|
||||
|
||||
bool operator==(const Argument& other) const;
|
||||
|
@ -58,7 +58,7 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
|
||||
return Get(ctx->op_kernel_context());
|
||||
}
|
||||
|
||||
void XlaContext::set_args(std::vector<Argument> args) {
|
||||
void XlaContext::set_args(std::vector<XlaExpression> args) {
|
||||
args_ = std::move(args);
|
||||
}
|
||||
|
||||
@ -78,8 +78,8 @@ XlaContext::GetOrCreateRuntimeContextParameter() {
|
||||
|
||||
// Allocate the next available parameter for the context parameter.
|
||||
int num_parameters = 0;
|
||||
for (const Argument& arg : args_) {
|
||||
if (!arg.value.is_constant) {
|
||||
for (const XlaExpression& arg : args_) {
|
||||
if (!arg.has_constant_value()) {
|
||||
++num_parameters;
|
||||
}
|
||||
}
|
||||
@ -99,9 +99,7 @@ void XlaContext::AddRetval(int retval_index, DataType type,
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
retvals_[retval_index].is_constant = false;
|
||||
retvals_[retval_index].type = type;
|
||||
retvals_[retval_index].handle = handle;
|
||||
retvals_[retval_index].set_handle(handle);
|
||||
}
|
||||
|
||||
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
@ -111,14 +109,12 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
retvals_[retval_index].type = dtype;
|
||||
if (resolve_compile_time_constants_) {
|
||||
retvals_[retval_index].is_constant = true;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(
|
||||
literal, dtype, &retvals_[retval_index].constant_value));
|
||||
Tensor value;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
|
||||
retvals_[retval_index].set_constant_value(std::move(value));
|
||||
} else {
|
||||
retvals_[retval_index].is_constant = false;
|
||||
retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
|
||||
retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/computation.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
@ -37,34 +38,6 @@ class XlaOpKernelContext;
|
||||
// subgraph of Ops using XLA.
|
||||
class XlaContext : public ResourceBase {
|
||||
public:
|
||||
// A struct that represents either a compile-time constant, or an XLA
|
||||
// computation handle. Used to represent arguments and return values.
|
||||
struct HandleOrConstant {
|
||||
// Is this a compile-time constant? If so, what is its value?
|
||||
bool is_constant;
|
||||
Tensor constant_value; // Must be in host memory.
|
||||
|
||||
// If this is not a constant, a computation handle. Since the mapping from
|
||||
// Tensorflow types to XLA types is not necessarily injective (one-to-one),
|
||||
// we also require the Tensorflow type.
|
||||
DataType type;
|
||||
xla::ComputationDataHandle handle;
|
||||
};
|
||||
|
||||
struct Argument {
|
||||
XlaCompiler::Argument::Kind kind;
|
||||
|
||||
// Descriptive name for the resource, for use in error messages.
|
||||
string name;
|
||||
|
||||
// Is this a resource?
|
||||
bool is_resource = false;
|
||||
|
||||
HandleOrConstant value;
|
||||
|
||||
int64 tensor_array_size = -1;
|
||||
};
|
||||
|
||||
// Retrieves the XlaContext of the current compilation.
|
||||
static XlaContext& Get(const OpKernelContext* ctx);
|
||||
static XlaContext& Get(const XlaOpKernelContext* ctx);
|
||||
@ -85,14 +58,14 @@ class XlaContext : public ResourceBase {
|
||||
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
|
||||
bool has_context_parameter() const { return has_context_parameter_; }
|
||||
|
||||
const std::vector<Argument>& args() const { return args_; }
|
||||
void set_args(std::vector<Argument> args);
|
||||
const std::vector<XlaExpression>& args() const { return args_; }
|
||||
void set_args(std::vector<XlaExpression> args);
|
||||
|
||||
// Get the runtime context parameter, adding one if it does not already exist.
|
||||
// Dies if not compiling a local executable.
|
||||
const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
|
||||
|
||||
const std::vector<HandleOrConstant>& retvals() { return retvals_; }
|
||||
const std::vector<XlaExpression>& retvals() { return retvals_; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
@ -155,10 +128,10 @@ class XlaContext : public ResourceBase {
|
||||
|
||||
// Arguments to the Tensorflow graph, indexed by _Arg index.
|
||||
// Includes both compile-time constant arguments and runtime parameters.
|
||||
std::vector<Argument> args_;
|
||||
std::vector<XlaExpression> args_;
|
||||
|
||||
// Return values of the Tensorflow graph, indexed by _Retval index.
|
||||
std::vector<HandleOrConstant> retvals_;
|
||||
std::vector<XlaExpression> retvals_;
|
||||
|
||||
// Does the computation have side effects, i.e., Send() calls?
|
||||
bool has_side_effects_ = false;
|
||||
|
Loading…
Reference in New Issue
Block a user