Yunxing Dai f86b74e27e [XLA][TF2XLA] Support tensor list with dynamic dimension.
Previously we didn't allow a dynamic dimension to change within an HLO while
loop, but this constraint breaks tensor lists, where the true dynamic
dimension size is only known inside the loop body.

This CL:
- Add support in the dynamic padder for changing a dynamic dimension's size inside the loop.
- Add a test demonstrating how a tensor list / stack can be handled more elegantly in XLA.
- Add the machinery needed to wire this feature into tf2xla.

PiperOrigin-RevId: 307901191
Change-Id: I4d39f1d8a8c944f1e9834c39599e6cfbc99f6807
2020-04-22 14:37:18 -07:00
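
For context, here is a minimal sketch of the pattern this change enables, written against the public XlaBuilder API (xla::SetDimensionSize, xla::Add, xla::ConstantR0); the helper name PushStep, the buffer layout, and the push logic are illustrative and not taken from this CL.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Illustrative loop-body step for a tensor-list-like push: `buffer` is a
// bounded tensor whose dimension 0 is dynamic, and `size` is an S32 scalar
// holding the current number of valid elements.
xla::XlaOp PushStep(xla::XlaBuilder* b, xla::XlaOp buffer, xla::XlaOp size) {
  // ... write the new element into `buffer` at index `size` ...
  // Report the grown size on dimension 0; the static shape only supplies the
  // upper bound. With this CL the dynamic size may differ across while-loop
  // iterations instead of having to stay constant.
  xla::XlaOp one = xla::ConstantR0<int32_t>(b, 1);
  xla::XlaOp new_size = xla::Add(size, one);
  return xla::SetDimensionSize(buffer, new_size, /*dimension=*/0);
}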

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
namespace xla {
using absl::StrCat;

// Creates a while-condition computation that takes `wide_shape` as its
// parameter, calls `narrow_condition` on the leading prefix of that
// parameter, and then inlines the call.
static StatusOr<HloComputation*> WidenWhileCondition(
    HloComputation* narrow_condition, const Shape& wide_shape) {
  const Shape& narrow_shape =
      narrow_condition->parameter_instruction(0)->shape();

  HloComputation* wide_while_cond = [&]() {
    HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));

    // This is needed so that the root instruction is shaped as a PRED[] -- we
    // need to get this right to begin with since we can't mutate the type of
    // the root instruction later. We later change the root instruction to
    // something more appropriate.
    builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
  }();

  HloInstruction* truncated_parameter =
      TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
                               narrow_shape.tuple_shapes_size());
  HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
      HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
                                 {truncated_parameter}, narrow_condition));

  wide_while_cond->set_root_instruction(call_narrow_cond);

  TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
  return wide_while_cond;
}

// Like WidenWhileCondition, but for the loop body: the widened body calls
// (and then inlines) `narrow_body` on the leading prefix of its parameter and
// passes the remaining tuple elements through unchanged.
static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
  const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();

  HloComputation* wide_while_body = [&]() {
    HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
    return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
  }();

  HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
  HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
      wide_parameter, narrow_shape.tuple_shapes_size());
  HloInstruction* call_narrow_body =
      wide_while_body->AddInstruction(HloInstruction::CreateCall(
          narrow_shape, {truncated_parameter}, narrow_body));

  std::vector<HloInstruction*> live_through_values;
  for (int i = narrow_shape.tuple_shapes_size();
       i < wide_shape.tuple_shapes_size(); i++) {
    live_through_values.push_back(
        wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            wide_shape.tuple_shapes(i), wide_parameter, i)));
  }

  wide_while_body->set_root_instruction(
      TupleUtil::AppendSuffix(call_narrow_body, live_through_values));

  TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
                      CallInliner::Inline(call_narrow_body));
  return {{wide_while_body, std::move(inlined_instructions_map)}};
}

/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
    HloInstruction* while_instr,
    absl::Span<HloInstruction* const> instructions) {
  CHECK(while_instr->shape().IsTuple());

  int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
  Shape new_while_shape = while_instr->shape();
  for (auto* instruction : instructions) {
    *new_while_shape.add_tuple_shapes() = instruction->shape();
  }

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_while_condition,
      WidenWhileCondition(while_instr->while_condition(), new_while_shape));

  HloComputation* new_while_body;
  CallInliner::InlinedInstructionMap inlined_instructions_map;
  TF_ASSIGN_OR_RETURN(
      std::tie(new_while_body, inlined_instructions_map),
      WidenWhileBody(while_instr->while_body(), new_while_shape));

  HloInstruction* new_while_init =
      TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
  HloComputation* containing_computation = while_instr->parent();
  HloInstruction* new_while = containing_computation->AddInstruction(
      HloInstruction::CreateWhile(new_while_shape, new_while_condition,
                                  new_while_body, new_while_init));

  // We want to get rid of the old while instruction even if it has side
  // effecting operations so we do a manual HloComputation::RemoveInstruction
  // instead of relying on HloComputation::ReplaceInstruction.
  HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
      new_while, while_instr->shape().tuple_shapes_size());
  TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
  TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));

  HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> live_in_instructions;
  for (int64 i = elements_in_old_while_shape;
       i < new_while_shape.tuple_shapes_size(); i++) {
    live_in_instructions.push_back(
        new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            instructions[i - elements_in_old_while_shape]->shape(),
            while_body_param, i)));
  }

  WhileUtil::MakeInstructionsLiveInResult result;

  result.new_while_instr = new_while;
  result.replacement_instr = replacement_instr;
  result.while_body_live_in_values = std::move(live_in_instructions);
  result.while_body_instruction_map = std::move(inlined_instructions_map);

  return std::move(result);
}

// Builds the condition computation for a counted loop: keep iterating while
// the induction variable (tuple element 0 of the loop state) is less than
// `trip_count`.
static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
                                    int32 trip_count) {
  Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
                      CreateComputationWithSignature(
                          {&loop_state_shape}, scalar_pred, "while_cond"));

  HloInstruction* trip_count_constant = cond_computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));

  HloInstruction* param = cond_computation->parameter_instruction(0);
  TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
                      MakeGetTupleElementHlo(param, 0));

  TF_ASSIGN_OR_RETURN(
      HloInstruction * compare,
      MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
  cond_computation->set_root_instruction(compare);

  return std::move(cond_computation);
}

// Builds the body computation for a counted loop: increments the induction
// variable and produces the rest of the next loop state by invoking
// `loop_body_generator`.
static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
    const Shape& loop_state_shape,
    const std::function<StatusOr<WhileUtil::LoopStateTy>(
        HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
                      CreateComputationWithSignature(
                          {&loop_state_shape}, loop_state_shape, "while_body"));
  HloInstruction* one = body_computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  HloInstruction* param = body_computation->parameter_instruction(0);
  TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
                      MakeGetTupleElementHlo(param, 0));
  TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
                      MakeBinaryHlo(HloOpcode::kAdd, indvar, one));

  std::vector<HloInstruction*> loop_body_generator_args;
  for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
    TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
                        MakeGetTupleElementHlo(param, i));
    loop_body_generator_args.push_back(tuple_element);
  }
  TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
                      loop_body_generator(indvar, loop_body_generator_args));
  next_state.insert(next_state.begin(), next_indvar);
  HloInstruction* next_state_tuple =
      body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
  body_computation->set_root_instruction(next_state_tuple);

  return std::move(body_computation);
}

// Builds the initial loop state tuple: a zero induction variable followed by
// `init_values`.
static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
    HloComputation* computation, const WhileUtil::LoopStateTy& init_values) {
  std::vector<HloInstruction*> init_values_with_indvar;
  init_values_with_indvar.reserve(init_values.size() + 1);
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  init_values_with_indvar.push_back(zero);
  absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
  return computation->AddInstruction(
      HloInstruction::CreateTuple(init_values_with_indvar));
}

// Returns a tuple shape containing an S32, and a shape from each value in
// `init_values`. If a shape from a value in `init_values` doesn't have a
// layout, use a default layout for the shape.
static Shape MakeLoopStateShapeWithLayout(
    const WhileUtil::LoopStateTy& init_values) {
  std::vector<Shape> loop_state_shape_components;
  loop_state_shape_components.reserve(init_values.size() + 1);
  loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
  absl::c_transform(init_values,
                    std::back_inserter(loop_state_shape_components),
                    [](HloInstruction* instr) {
                      Shape shape = instr->shape();
                      if (!shape.has_layout()) {
                        LayoutUtil::SetToDefaultLayout(&shape);
                      }
                      return shape;
                    });
  return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}

/*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
    HloComputation* computation, int32 trip_count,
    const WhileUtil::LoopStateTy& init_values,
    const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
    const OpMetadata& metadata) {
  CHECK_GE(trip_count, 0);

  // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation
  // use loop_state_shape to create a literal, which requires loop_state_shape
  // to have a layout.
  Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values);

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloComputation> cond,
      MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloComputation> body,
      MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
  TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple,
                      MakeInitTupleFromInitValues(computation, init_values));

  HloModule* module = computation->parent();
  HloInstruction* while_instr =
      computation->AddInstruction(HloInstruction::CreateWhile(
          loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
          module->AddEmbeddedComputation(std::move(body)), init_tuple));
  while_instr->set_metadata(metadata);

  std::vector<HloInstruction*> result;
  for (int64 i = 0, e = init_values.size(); i < e; i++) {
    TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
                        MakeGetTupleElementHlo(while_instr, i + 1));
    result.push_back(user_state);
  }
  return result;
}

// Returns the GetTupleElement instructions in `while_body` that read tuple
// element i of the parameter and appear unchanged as operand i of the root,
// i.e. the loop-invariant elements of the loop state.
/*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
    const HloComputation& while_body) {
  std::vector<HloInstruction*> result;
  const HloInstruction::InstructionVector root_operands =
      while_body.root_instruction()->operands();
  for (int i = 0; i < root_operands.size(); i++) {
    HloInstruction* instr = root_operands[i];
    if (instr->opcode() == HloOpcode::kGetTupleElement &&
        instr->tuple_index() == i &&
        instr->operand(0) == while_body.parameter_instruction(0)) {
      result.push_back(instr);
    }
  }
  return result;
}

// Groups the GetTupleElement users of the while condition's parameter by
// tuple index.
/*static*/ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
WhileUtil::GetGTEsMapForWhileConditional(
    const HloComputation& while_conditional) {
  absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> result;
  for (HloInstruction* user :
       while_conditional.parameter_instruction(0)->users()) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      result[user->tuple_index()].push_back(user);
    }
  }
  return result;
}

} // namespace xla
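
For reference, a hedged usage sketch of WhileUtil::MakeCountedLoop, based only on the signatures visible above; the wrapper name RepeatedAdd, its arguments, and the add-accumulator loop are hypothetical, and `a` and `b` are assumed to be same-shaped instructions already in `computation`.

// Hypothetical caller (not part of the original file): runs `trip_count`
// iterations that repeatedly add `b` into an accumulator seeded with `a`.
// Error handling beyond the TF_* macros is elided.
StatusOr<HloInstruction*> RepeatedAdd(HloComputation* computation,
                                      HloInstruction* a, HloInstruction* b,
                                      int32 trip_count) {
  TF_ASSIGN_OR_RETURN(
      WhileUtil::LoopStateTy final_state,
      WhileUtil::MakeCountedLoop(
          computation, trip_count, /*init_values=*/{a, b},
          [](HloInstruction* indvar, const WhileUtil::LoopStateTy& state)
              -> StatusOr<WhileUtil::LoopStateTy> {
            // One iteration: accumulator += addend; the addend is threaded
            // through the loop state unchanged.
            TF_ASSIGN_OR_RETURN(
                HloInstruction * sum,
                MakeBinaryHlo(HloOpcode::kAdd, state[0], state[1]));
            return WhileUtil::LoopStateTy{sum, state[1]};
          },
          OpMetadata()));
  return final_state[0];  // The accumulator after the final iteration.
}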