Support inferring dynamism of a reduce that shows up multiple times in a kSelect operand list.

This CL supports the case where a reduce is a common ancestor of multiple operands of a kSelect.

  Reduce
   /  \
  ..   ..   ..
  |0   |1   |2
     Select

We do that by copying the reduce twice into the inference graph:

Reduce (original form)  Reduce (predicate form)
  |                      /
  ..                    ..
  |    +---------------/
  |0   |1   +--2-- ..
     Select
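
For reference, the new ReduceUsedTwice test added below builds exactly this shape from the client API. A minimal sketch (variable names are illustrative):

    XlaBuilder b("reduce_used_twice");
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
    auto zero = ConstantR0<int32>(&b, 0);
    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
    // The reduce reaches the select through both the predicate and operand 1,
    // making it a common ancestor of multiple kSelect operands.
    auto reduce = Reduce(p, zero, add_s32, {0});
    auto pred = Eq(c, reduce);
    auto result = Select(pred, reduce, c);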

PiperOrigin-RevId: 346587426
Change-Id: Ic24cff2d147886a685b85d22c1d635c9c7901367
Yunxing Dai 2020-12-09 10:49:50 -08:00 committed by TensorFlower Gardener
parent 193e2c295d
commit b70529fa16
3 changed files with 68 additions and 19 deletions

tensorflow/compiler/xla/client/xla_builder.cc

@@ -94,9 +94,12 @@ void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
   *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
-// Converts a HloComputation into ReducerOr with predicate types.
-HloComputationProto CreateReduceOr(int64 reducer_id,
-                                   HloComputationProto* original_reducer) {
+// Copy `original_reducer` into a new computation proto with `reducer_id` as new
+// id. If `rewrite_into_pred` is true, the instructions in the reducer are
+// rewritten into predicate form.
+HloComputationProto CopyReducer(int64 reducer_id,
+                                HloComputationProto* original_reducer,
+                                bool rewrite_into_pred, int64* global_id) {
   HloComputationProto reducer;
   SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
   std::vector<int64> operands_id;
@@ -106,19 +109,28 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
         HloOpcode::kParameter) {
       HloInstructionProto* new_param = reducer.add_instructions();
       *new_param = inst;
-      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      operands_id.push_back(inst.id());
+      new_param->set_id((*global_id)++);
+      *new_param->mutable_name() =
+          GetFullName(inst.name(), '.', new_param->id());
+      if (rewrite_into_pred) {
+        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      }
+      operands_id.push_back(new_param->id());
     }
     if (inst.id() == original_reducer->root_id()) {
       HloInstructionProto* new_root = reducer.add_instructions();
       *new_root = inst;
-      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      new_root->set_id((*global_id)++);
+      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
+      if (rewrite_into_pred) {
+        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      }
       new_root->clear_operand_ids();
       for (int64 operand_id : operands_id) {
         new_root->add_operand_ids(operand_id);
       }
-      reducer.set_root_id(inst.id());
+      reducer.set_root_id(new_root->id());
     }
   }
   return reducer;
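
Concretely, CopyReducer now always re-ids the copied parameters and root from the shared global counter, and only converts shapes (and the root opcode) when rewrite_into_pred is true. For a scalar-add reducer, the two copies look roughly like this (HLO-style sketch; names and ids are hypothetical):

    // rewrite_into_pred = false: value form, add semantics preserved.
    reduce_or.10 { p0.11 = s32[] parameter(0)
                   p1.12 = s32[] parameter(1)
                   add.13 = s32[] add(p0.11, p1.12) }  // root
    // rewrite_into_pred = true: predicate form, dynamism bits are OR'ed.
    reduce_or.14 { p0.15 = pred[] parameter(0)
                   p1.16 = pred[] parameter(1)
                   or.17 = pred[] or(p0.15, p1.16) }   // root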
@@ -3323,7 +3335,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::vector<HloComputationProto> called_computatons;
+  std::vector<HloComputationProto> called_computations;
   auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                  int64 operand_index) -> StatusOr<bool> {
     int64 operand_id = instr_proto->operand_ids(operand_index);
@@ -3336,7 +3348,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   // graph with have id set to `id`.
   auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                  bool need_rewrite, int64 id,
-                                 absl::Span<int64 const> operand_ids) {
+                                 absl::Span<int64 const> operand_ids,
+                                 int64* global_id) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3364,6 +3377,17 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     if (!need_rewrite) {
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      if (opcode == HloOpcode::kReduce) {
+        // Copy the reducer to the new module, with a new id that's same as the
+        // reduce op.
+        HloComputationProto* reducer =
+            &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
+      }
       return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
@@ -3439,9 +3463,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         break;
       }
       case HloOpcode::kReduce: {
-        int64 reducer_id = new_instr->called_computation_ids(0);
-        called_computatons.push_back(
-            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
         break;
       }
       case HloOpcode::kTuple:
@@ -3567,10 +3594,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     if (next_operand >= instr_proto->operand_ids_size() ||
         !should_visit_operand || InstrIsSetBound(instr_proto)) {
       // No more operands to process, process self.
-      int64 new_id = ++global_id;
+      int64 new_id = global_id++;
      VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
       TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
-                                             new_id, item.processed_operands));
+                                             new_id, item.processed_operands,
+                                             &global_id));
       stacktop_id = new_id;
       seen[item_key] = stacktop_id;
       worklist.pop_back();
@@ -3602,10 +3630,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   module->set_entry_computation_name(entry.name());
   module->set_entry_computation_id(entry.id());
   *module->mutable_host_program_shape() = *program_shape;
-  for (auto& called_comp : called_computatons) {
+  for (auto& called_comp : called_computations) {
     *module->add_computations() = called_comp;
   }
   *module->add_computations() = std::move(entry);
+  // Make sure all ids appear in the computation with ascending order.
+  absl::c_sort(*module->mutable_computations(),
+               [](const HloComputationProto& c1,
+                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
   XLA_VLOG_LINES(3, module->DebugString());
   return std::move(computation);
 }
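
To see how the ids flow, here is a hypothetical assignment while visiting one reduce that needs both forms (numbers are illustrative):

    // Predicate-form visit: reduce op gets id 7, its copied reducer
    // computation gets id 8, and the reducer's params/root get ids 9..11,
    // all drawn from the same global counter.
    // Original-form visit:  reduce op id 12, copied reducer id 13, and so on.

Since both visits draw from one counter, every computation and instruction id in the inference module stays unique, and the final absl::c_sort keeps the serialized computations in ascending id order.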

tensorflow/compiler/xla/tests/BUILD

@@ -2098,7 +2098,10 @@ xla_test(
     name = "dynamism_inference_test",
     srcs = ["dynamism_inference_test.cc"],
     deps = [
+        ":literal_test_util",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -2109,10 +2112,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:prng",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "@com_google_absl//absl/strings",

tensorflow/compiler/xla/tests/dynamism_inference_test.cc

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -162,6 +163,21 @@ TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
   }
 }
 
+TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
+    auto zero = ConstantR0<int32>(&b, 0);
+    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
+    auto reduce = Reduce(p, zero, add_s32, {0});
+    auto pred = Eq(c, reduce);
+    auto result = Select(pred, reduce, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);