Merge pull request #41734 from tg-at-google:wsign-compare-semi-final-tf-c-compiler-jit
PiperOrigin-RevId: 324069278 Change-Id: I80ef17b15db7e847c6025abeddf498b15c05f10d
This commit is contained in:
commit
64cc68feb9
@ -172,7 +172,7 @@ string RewriteWithName(const string& name, string code,
|
||||
Status GenArgMethods(const tf2xla::Config& config,
|
||||
const xla::ProgramShapeProto& ps,
|
||||
const CompileResult& compile_result, string* methods) {
|
||||
size_t num_args = ps.parameters_size();
|
||||
const int num_args = ps.parameters_size();
|
||||
// feed_size() + variable_size() is the maximum number of args as an
|
||||
// implementation may not create an argument for an unused variable.
|
||||
if (config.feed_size() + config.variable_size() < num_args) {
|
||||
@ -229,8 +229,9 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
int readonly_variables = absl::c_count_if(
|
||||
config.variable(),
|
||||
[](const tf2xla::Variable& var) { return var.readonly(); });
|
||||
if (config.fetch_size() + config.variable_size() - readonly_variables !=
|
||||
num_results) {
|
||||
const int actual_num_results =
|
||||
config.fetch_size() + config.variable_size() - readonly_variables;
|
||||
if (actual_num_results != num_results) {
|
||||
return errors::InvalidArgument("mismatch between fetch_size(",
|
||||
config.fetch_size(), ")+variable_size(",
|
||||
config.variable_size(), ") and tuple_size(",
|
||||
@ -273,7 +274,7 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
// Generate methods for variables.
|
||||
Status GenVariableMethods(const tf2xla::Config& config,
|
||||
const xla::ProgramShapeProto& ps, string* methods) {
|
||||
size_t num_args = ps.parameters_size();
|
||||
const int num_args = ps.parameters_size();
|
||||
for (int i = config.feed_size(); i < num_args; ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -401,7 +402,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
|
||||
std::vector<string> buffer_infos_as_strings =
|
||||
BufferInfosToCppExpression(buffer_infos);
|
||||
if (result_index < 0 || result_index >= buffer_infos.size()) {
|
||||
const int64 buffer_infos_size = buffer_infos.size();
|
||||
if (result_index < 0 || result_index >= buffer_infos_size) {
|
||||
return errors::InvalidArgument("result index: ", result_index,
|
||||
" is outside the range of temp sizes: [0,",
|
||||
buffer_infos.size(), ")");
|
||||
@ -797,8 +799,8 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
|
||||
// Allow a fully qualified name that starts with "::".
|
||||
parts.erase(parts.begin());
|
||||
}
|
||||
for (int i = 0; i < parts.size(); ++i) {
|
||||
if (i < parts.size() - 1) {
|
||||
for (int i = 0, end = parts.size(); i < end; ++i) {
|
||||
if (i < end - 1) {
|
||||
TF_RETURN_IF_ERROR(ValidateCppIdent(
|
||||
parts[i], "in namespace component of cpp_class: " + cpp_class));
|
||||
namespaces->push_back(parts[i]);
|
||||
|
@ -452,7 +452,7 @@ Status PredicateInt32Inputs(const Scope& root, Node* n,
|
||||
root.graph()->AddControlEdge(predicate_as_control.node(),
|
||||
identity_n.operation.node());
|
||||
|
||||
for (int i = 0; i < int32_inputs.size(); i++) {
|
||||
for (int i = 0, end = int32_inputs.size(); i < end; i++) {
|
||||
TF_RETURN_IF_ERROR(root.graph()->UpdateEdge(identity_n[i].node(), i, n,
|
||||
int32_inputs_input_idxs[i]));
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ class RecursiveCompilabilityChecker {
|
||||
UncompilableNodesMap* uncompilable_nodes_map);
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const int kMaxRecursionDepth = 10;
|
||||
const size_t kMaxRecursionDepth = 10;
|
||||
|
||||
const OperationFilter& op_filter_;
|
||||
const DeviceType& jit_device_type_;
|
||||
|
@ -26,8 +26,8 @@ using xla::StatusOr;
|
||||
void DeviceSet::Insert(DeviceId device_id) {
|
||||
int word_index = device_id.id() / kWordSize;
|
||||
int bit_index = device_id.id() % kWordSize;
|
||||
|
||||
if (word_index >= storage_.size()) {
|
||||
const int storage_size = storage_.size();
|
||||
if (word_index >= storage_size) {
|
||||
storage_.resize(word_index + 1, 0);
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ void DeviceSet::UnionWith(const DeviceSet& other) {
|
||||
storage_.resize(other.storage_.size(), 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < other.storage_.size(); i++) {
|
||||
for (int i = 0, end = other.storage_.size(); i < end; i++) {
|
||||
storage_[i] |= other.storage_[i];
|
||||
}
|
||||
}
|
||||
|
@ -72,7 +72,8 @@ class DeviceSet {
|
||||
void ForEach(FnTy func) const {
|
||||
// This is really a poor man's iterator, we should consider writing a proper
|
||||
// iterator if this ends up being used widely.
|
||||
for (int word_index = 0; word_index < storage_.size(); word_index++) {
|
||||
for (int word_index = 0, end = storage_.size(); word_index < end;
|
||||
word_index++) {
|
||||
uint64 word = storage_[word_index];
|
||||
while (word != 0) {
|
||||
uint64 only_lowest_bit_set = word & -word;
|
||||
|
@ -1132,7 +1132,8 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
if (index < 0 || index >= types->size()) {
|
||||
const int num_types = types->size();
|
||||
if (index < 0 || index >= num_types) {
|
||||
return errors::InvalidArgument("Invalid argument number");
|
||||
}
|
||||
(*types)[index] = n->output_type(0);
|
||||
@ -1149,7 +1150,8 @@ static Status RenumberArguments(Graph* graph,
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
if (index < 0 || index >= permutation.size()) {
|
||||
const int permutation_size = permutation.size();
|
||||
if (index < 0 || index >= permutation_size) {
|
||||
return errors::InvalidArgument("Invalid argument number");
|
||||
}
|
||||
n->AddAttr("index", permutation[index]);
|
||||
|
@ -139,7 +139,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||
// outside compilation node input.
|
||||
std::map<std::pair<string, int>, Node*> placeholders;
|
||||
for (int i = 0; i < edges.size(); i++) {
|
||||
for (int i = 0, end = edges.size(); i < end; i++) {
|
||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
|
||||
@ -185,7 +185,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
||||
// Other edge in `edges` might have `e->dst()` as src or dst
|
||||
// node. Before removing `e->dst()`, replace those edges with
|
||||
// corresponding edges for `dst_replace_node`.
|
||||
for (int j = i + 1; j < edges.size(); j++) {
|
||||
for (int j = i + 1, end = edges.size(); j < end; j++) {
|
||||
if (edges[j].dst_node_id == edges[i].dst_node_id) {
|
||||
edges[j].dst_node_id = dst_replace_node->id();
|
||||
}
|
||||
@ -238,7 +238,7 @@ Status PostprocessDataEdgesBetweenOutsideCompilations(
|
||||
g->AddControlEdge(original_node, e->dst());
|
||||
g->RemoveEdge(e);
|
||||
}
|
||||
for (int i = 0; i < data_edges.size(); i++) {
|
||||
for (int i = 0, end = data_edges.size(); i < end; i++) {
|
||||
Node* dst = data_edges[i].dst;
|
||||
NodeDef new_def = dst->def();
|
||||
int dst_input = data_edges[i].dst_input;
|
||||
@ -253,7 +253,7 @@ Status PostprocessDataEdgesBetweenOutsideCompilations(
|
||||
|
||||
// Other edges might have `dst` as dst node. Update those edges with
|
||||
// `replace_node`.
|
||||
for (int j = i + 1; j < data_edges.size(); j++) {
|
||||
for (int j = i + 1, end = data_edges.size(); j < end; j++) {
|
||||
if (data_edges[j].dst == dst) {
|
||||
data_edges[j].dst = replace_node;
|
||||
}
|
||||
|
@ -351,14 +351,14 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
for (int i = 0; i < data_inputs.size(); ++i) {
|
||||
for (int i = 0, end = data_inputs.size(); i < end; ++i) {
|
||||
graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
|
||||
i);
|
||||
}
|
||||
for (Node* n : control_inputs) {
|
||||
graph->AddControlEdge(n, xla_launch);
|
||||
}
|
||||
for (int i = 0; i < data_outputs.size(); ++i) {
|
||||
for (int i = 0, end = data_outputs.size(); i < end; ++i) {
|
||||
for (const auto& successor : data_outputs[i]) {
|
||||
graph->AddEdge(xla_launch, i, successor.first, successor.second);
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
|
||||
(*recv_at_host_dtypes)[index] = dtype;
|
||||
}
|
||||
for (int i = 0; i < recv_at_host_dtypes->size(); i++) {
|
||||
for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
|
||||
if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
|
||||
return errors::Internal("Cannot get datatype for input ", i);
|
||||
}
|
||||
@ -160,7 +160,7 @@ xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
|
||||
}
|
||||
|
||||
// Rewrite dst nodes because their input changed.
|
||||
for (int i = 0; i < out_edge_info.size(); i++) {
|
||||
for (int i = 0, end = out_edge_info.size(); i < end; i++) {
|
||||
const OutEdgeInfo edge = out_edge_info[i];
|
||||
if (edge.dst_input == Graph::kControlSlot) {
|
||||
continue;
|
||||
@ -174,7 +174,7 @@ xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
|
||||
|
||||
// Other edges might have `dst` as dst node as well. Update those edges
|
||||
// with `dst_replace`.
|
||||
for (int j = i + 1; j < out_edge_info.size(); j++) {
|
||||
for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
|
||||
if (out_edge_info[j].dst == dst) {
|
||||
out_edge_info[j].dst = dst_replace;
|
||||
}
|
||||
@ -196,7 +196,7 @@ Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
|
||||
(*send_from_host_dtypes)[index] = dtype;
|
||||
}
|
||||
for (int i = 0; i < send_from_host_dtypes->size(); i++) {
|
||||
for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
|
||||
if ((*send_from_host_dtypes)[i] == DT_INVALID) {
|
||||
return errors::Internal("Cannot get datatype for output ", i);
|
||||
}
|
||||
@ -226,7 +226,8 @@ xla::StatusOr<Node*> BuildSendFromHostNode(
|
||||
for (auto* n : ret_nodes) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
if (index < 0 || index >= send_from_host_dtypes.size()) {
|
||||
const int num_dtypes = send_from_host_dtypes.size();
|
||||
if (index < 0 || index >= num_dtypes) {
|
||||
return errors::Internal("Invalid _Retval index: ", index);
|
||||
}
|
||||
for (auto edge : n->in_edges()) {
|
||||
@ -361,7 +362,8 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) {
|
||||
const int input_dtypes_size = input_dtypes.size();
|
||||
if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
|
||||
return errors::Internal("Invalid dst_input: ", e->dst_input());
|
||||
}
|
||||
inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
|
||||
@ -500,7 +502,7 @@ void AddEdgesFromOutsideCompilationNodes(
|
||||
const std::vector<DataType>& data_types,
|
||||
const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
|
||||
// Add edges from outside compilation nodes to While node.
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
|
||||
Node* outside_compilation_node =
|
||||
outside_compilation_nodes[i - original_arg_count];
|
||||
g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
|
||||
@ -619,7 +621,7 @@ Status PostprocessLiftedArgsForWhile(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(lifted_arg_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.first; });
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
|
||||
TF_ASSIGN_OR_RETURN(Node * arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(
|
||||
*body_function_body, i, data_types[i]));
|
||||
@ -648,7 +650,7 @@ Status PostprocessLiftedArgsForWhile(
|
||||
AttrSlice(&cond_func.attr()), fld,
|
||||
&cond_function_body));
|
||||
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
|
||||
xla::StatusOr<Node*> arg_node_or =
|
||||
AddOutsideCompilationInputArgToFunctionBody(*cond_function_body, i,
|
||||
data_types[i]);
|
||||
@ -759,7 +761,7 @@ Status PostprocessLiftedArgsForIf(
|
||||
data_types, outside_compilation_nodes, g,
|
||||
n);
|
||||
|
||||
for (int i = original_arg_count; i < data_types.size(); ++i) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(
|
||||
*then_branch_function_body, i, data_types[i]));
|
||||
@ -837,7 +839,7 @@ Status PostprocessLiftedArgsForCall(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(lifted_arg_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.first; });
|
||||
for (int i = original_arg_count; i < data_types.size(); ++i) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Node * arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
|
||||
@ -855,7 +857,7 @@ Status PostprocessLiftedArgsForCall(
|
||||
// We need to recreate the node. Otherwise TF will not know n->num_inputs()
|
||||
// has increased.
|
||||
NodeDef node_def = n->def();
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
|
||||
Node* outside_compilation_node =
|
||||
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
|
||||
.second;
|
||||
@ -1804,7 +1806,9 @@ TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size());
|
||||
const bool input_size_check =
|
||||
e->dst_input() < static_cast<int>(inputs.size());
|
||||
TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
|
||||
inputs[e->dst_input()] =
|
||||
NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
|
||||
e->src()->output_type(e->src_output())};
|
||||
|
@ -461,7 +461,7 @@ string GraphCycles::DebugString() const {
|
||||
}
|
||||
|
||||
string result = "digraph {\n";
|
||||
for (int i = 0; i < rep_->nodes_.size(); i++) {
|
||||
for (int i = 0, end = rep_->nodes_.size(); i < end; i++) {
|
||||
if (free_nodes_set.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ Status ComputeSliceSize(const Scope& host_scope,
|
||||
ConstantCache constant_pool(host_scope, control_deps);
|
||||
|
||||
std::vector<Output> slice_size;
|
||||
for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) {
|
||||
for (int i = 0, end = slice_inputs.size_as_vector.size(); i < end; i++) {
|
||||
if (slice_inputs.size_as_vector[i] >= 0) {
|
||||
slice_size.push_back(
|
||||
constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
|
||||
|
@ -36,7 +36,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
|
||||
if (!context->RankKnown(handle)) return Status::OK();
|
||||
|
||||
std::vector<int64> dims(context->Rank(handle));
|
||||
for (int32 i = 0; i < dims.size(); ++i) {
|
||||
for (int32 i = 0, end = dims.size(); i < end; ++i) {
|
||||
dims[i] = context->Value(context->Dim(handle, i));
|
||||
}
|
||||
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
|
||||
|
@ -489,7 +489,7 @@ Status GetNodesRelatedToRefVariablesInDirection(
|
||||
/*stable_comparator=*/NodeComparatorName());
|
||||
}
|
||||
|
||||
int old_result_size;
|
||||
size_t old_result_size;
|
||||
int iterations = 0;
|
||||
|
||||
const int kMaxIterations = 10 * 1000;
|
||||
|
@ -97,7 +97,7 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
|
||||
if (arg_shapes != other.arg_shapes) return false;
|
||||
|
||||
if (arg_values.size() != other.arg_values.size()) return false;
|
||||
for (int i = 0; i < arg_values.size(); ++i) {
|
||||
for (int i = 0, end = arg_values.size(); i < end; ++i) {
|
||||
if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
|
||||
arg_values[i].shape() != other.arg_values[i].shape() ||
|
||||
arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
|
||||
@ -158,7 +158,7 @@ Status XlaCompilationCache::BuildExecutable(
|
||||
|
||||
std::vector<const xla::Shape*> argument_layouts(
|
||||
result.xla_input_shapes.size());
|
||||
for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
|
||||
for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
|
||||
argument_layouts[i] = &result.xla_input_shapes[i];
|
||||
}
|
||||
xla::ExecutableBuildOptions build_options;
|
||||
@ -224,7 +224,7 @@ static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
|
||||
// Create dummy _Arg nodes. Link these to `node` and also via a control
|
||||
// dependency edge to the _SOURCE node.
|
||||
for (int64 i = 0; i < args.size(); ++i) {
|
||||
for (int64 i = 0, end = args.size(); i < end; ++i) {
|
||||
Node* node;
|
||||
string arg_name = absl::StrCat("_arg", i);
|
||||
Status status =
|
||||
@ -240,7 +240,7 @@ static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
}
|
||||
|
||||
// Similarly with return values, create dummy _Retval nodes fed by `node`.
|
||||
for (int64 i = 0; i < result_types.size(); ++i) {
|
||||
for (int64 i = 0, end = result_types.size(); i < end; ++i) {
|
||||
Node* node;
|
||||
string retval_name = absl::StrCat("_retval", i);
|
||||
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
|
||||
@ -271,7 +271,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
auto compile_op = [&](XlaCompiler* compiler,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
std::vector<DataType> result_dtypes(ctx->num_outputs());
|
||||
for (int i = 0; i < result_dtypes.size(); ++i) {
|
||||
for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
|
||||
result_dtypes[i] = ctx->expected_output_dtype(i);
|
||||
}
|
||||
|
||||
@ -330,7 +330,7 @@ Status XlaCompilationCache::CompileImpl(
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "num_inputs=" << args.size();
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
for (int i = 0, end = args.size(); i < end; i++) {
|
||||
VLOG(3) << i << ": " << args[i].HumanString();
|
||||
}
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
ResourceVarsSnapshot* result) {
|
||||
for (int i = 0; i < variable_indices.size(); i++) {
|
||||
for (int i = 0, end = variable_indices.size(); i < end; i++) {
|
||||
Var* var = variable_infos[i].var();
|
||||
(*result)[variable_indices[i]] =
|
||||
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
|
||||
@ -206,7 +206,8 @@ XlaComputationLaunchContext::PopulateInputs(
|
||||
|
||||
xla::TransferManager* transfer_manager =
|
||||
client_->backend().transfer_manager();
|
||||
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
|
||||
for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
|
||||
++i) {
|
||||
int arg_num = compilation_result->input_mapping[i];
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
@ -466,7 +467,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
|
||||
const TensorShape& shape = output_tensor_shapes[i];
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Populating output for retval " << i << " shape "
|
||||
@ -514,7 +515,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
}
|
||||
|
||||
// Apply variable updates, if any.
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
|
||||
++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
|
Loading…
Reference in New Issue
Block a user