Support inferring dynamism of a reduce that shows up multiple times in a kSelect operand list.

This CL supports the case where a reduce is a common ancestor of multiple operands of a kSelect:

        Reduce
       /  |  \
     ..   ..  ..
     |0   |1  |2
       Select

We do that by copying the reduce twice into the inference graph, once in its original form and once in predicate form:

Reduce (original form)    Reduce (predicate form)
      |                       /          \
      ..                     ..           ..
      |0                     |1           |2
                  Select
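
Concretely, the client pattern this handles looks like the following (a minimal sketch that mirrors the ReduceUsedTwice test added below; CreateScalarAddComputation comes from client/lib/arithmetic):

  XlaBuilder b("reduce_used_twice");
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
  auto zero = ConstantR0<int32>(&b, 0);
  XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
  auto reduce = Reduce(p, zero, add_s32, {0});  // the common ancestor
  auto pred = Eq(c, reduce);              // reaches Select through operand 0
  auto result = Select(pred, reduce, c);  // and again through operand 1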

PiperOrigin-RevId: 346587426
Change-Id: Ic24cff2d147886a685b85d22c1d635c9c7901367
Yunxing Dai 2020-12-09 10:49:50 -08:00, committed by TensorFlower Gardener
parent 193e2c295d
commit b70529fa16
3 changed files with 68 additions and 19 deletions


@@ -94,9 +94,12 @@ void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
   *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
-// Converts a HloComputation into ReducerOr with predicate types.
-HloComputationProto CreateReduceOr(int64 reducer_id,
-                                   HloComputationProto* original_reducer) {
+// Copy `original_reducer` into a new computation proto with `reducer_id` as new
+// id. If `rewrite_into_pred` is true, the instructions in the reducer are
+// rewritten into predicate form.
+HloComputationProto CopyReducer(int64 reducer_id,
+                                HloComputationProto* original_reducer,
+                                bool rewrite_into_pred, int64* global_id) {
   HloComputationProto reducer;
   SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
   std::vector<int64> operands_id;
@@ -106,19 +109,28 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
         HloOpcode::kParameter) {
       HloInstructionProto* new_param = reducer.add_instructions();
       *new_param = inst;
-      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      operands_id.push_back(inst.id());
+      new_param->set_id((*global_id)++);
+      *new_param->mutable_name() =
+          GetFullName(inst.name(), '.', new_param->id());
+      if (rewrite_into_pred) {
+        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      }
+      operands_id.push_back(new_param->id());
     }
     if (inst.id() == original_reducer->root_id()) {
       HloInstructionProto* new_root = reducer.add_instructions();
       *new_root = inst;
-      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      new_root->set_id((*global_id)++);
+      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
+      if (rewrite_into_pred) {
+        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      }
       new_root->clear_operand_ids();
       for (int64 operand_id : operands_id) {
         new_root->add_operand_ids(operand_id);
       }
-      reducer.set_root_id(inst.id());
+      reducer.set_root_id(new_root->id());
     }
   }
   return reducer;
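
For a scalar add reducer, the predicate-form copy therefore turns (hand-written sketch, not generated output):

  add(x: s32[], y: s32[]) -> s32[]    // original reducer root

into

  or(x: pred[], y: pred[]) -> pred[]  // predicate-form root

so reducing with it ORs the per-element dynamism bits together, while a rewrite_into_pred=false copy only renumbers ids and names.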
@@ -3323,7 +3335,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::vector<HloComputationProto> called_computatons;
+  std::vector<HloComputationProto> called_computations;
   auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                  int64 operand_index) -> StatusOr<bool> {
     int64 operand_id = instr_proto->operand_ids(operand_index);
@@ -3336,7 +3348,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   // graph with have id set to `id`.
   auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                  bool need_rewrite, int64 id,
-                                 absl::Span<int64 const> operand_ids) {
+                                 absl::Span<int64 const> operand_ids,
+                                 int64* global_id) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3364,6 +3377,17 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     if (!need_rewrite) {
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      if (opcode == HloOpcode::kReduce) {
+        // Copy the reducer to the new module, with a new id that's same as the
+        // reduce op.
+        HloComputationProto* reducer =
+            &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
+      }
       return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
@@ -3439,9 +3463,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         break;
       }
       case HloOpcode::kReduce: {
-        int64 reducer_id = new_instr->called_computation_ids(0);
-        called_computatons.push_back(
-            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
        break;
      }
      case HloOpcode::kTuple:
@@ -3567,10 +3594,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     if (next_operand >= instr_proto->operand_ids_size() ||
         !should_visit_operand || InstrIsSetBound(instr_proto)) {
       // No more operands to process, process self.
-      int64 new_id = ++global_id;
+      int64 new_id = global_id++;
       VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
       TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
-                                             new_id, item.processed_operands));
+                                             new_id, item.processed_operands,
+                                             &global_id));
       stacktop_id = new_id;
       seen[item_key] = stacktop_id;
       worklist.pop_back();
@@ -3602,10 +3630,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   module->set_entry_computation_name(entry.name());
   module->set_entry_computation_id(entry.id());
   *module->mutable_host_program_shape() = *program_shape;
-  for (auto& called_comp : called_computatons) {
+  for (auto& called_comp : called_computations) {
     *module->add_computations() = called_comp;
   }
   *module->add_computations() = std::move(entry);
+  // Make sure all ids appear in the computation with ascending order.
+  absl::c_sort(*module->mutable_computations(),
+               [](const HloComputationProto& c1,
+                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
   XLA_VLOG_LINES(3, module->DebugString());
   return std::move(computation);
 }
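
(Aside: ConvertShapeProtoToPred, used throughout these hunks, is not shown in the diff; presumably it is the obvious helper, consistent with the ChangeElementType(..., PRED) call above:

  ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
    return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
  }

i.e. it keeps the dimensions and only swaps the element type, one dynamism bit per element.)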


@@ -2098,7 +2098,10 @@ xla_test(
     name = "dynamism_inference_test",
     srcs = ["dynamism_inference_test.cc"],
     deps = [
+        ":literal_test_util",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -2109,10 +2112,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:prng",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "@com_google_absl//absl/strings",


@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -162,6 +163,21 @@ TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
   }
 }
 
+TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
+    auto zero = ConstantR0<int32>(&b, 0);
+    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
+    auto reduce = Reduce(p, zero, add_s32, {0});
+    auto pred = Eq(c, reduce);
+    auto result = Select(pred, reduce, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);