Support inferring dynamism of a reduce that shows up multiple times in a kSelect operand list.
This CL supports the case where a reduce is a common ancestor of kSelect's multiple operands: Reduce | / \ .. .. .. |0 |1 |2 Select. We do that by copying the reduce twice into the inference graph: Reduce (original form) and Reduce (predicate form) | / .. .. | +--------------/ |0 |1 +--2---.. Select. PiperOrigin-RevId: 346587426 Change-Id: Ic24cff2d147886a685b85d22c1d635c9c7901367
This commit is contained in:
parent
193e2c295d
commit
b70529fa16
@ -94,9 +94,12 @@ void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
|
||||
*instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
|
||||
}
|
||||
|
||||
// Converts a HloComputation into ReducerOr with predicate types.
|
||||
HloComputationProto CreateReduceOr(int64 reducer_id,
|
||||
HloComputationProto* original_reducer) {
|
||||
// Copy `original_reducer` into a new computation proto with `reducer_id` as new
|
||||
// id. If `rewrite_into_pred` is true, the instructions in the reducer are
|
||||
// rewritten into predicate form.
|
||||
HloComputationProto CopyReducer(int64 reducer_id,
|
||||
HloComputationProto* original_reducer,
|
||||
bool rewrite_into_pred, int64* global_id) {
|
||||
HloComputationProto reducer;
|
||||
SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
|
||||
std::vector<int64> operands_id;
|
||||
@ -106,19 +109,28 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
|
||||
HloOpcode::kParameter) {
|
||||
HloInstructionProto* new_param = reducer.add_instructions();
|
||||
*new_param = inst;
|
||||
*new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
|
||||
operands_id.push_back(inst.id());
|
||||
new_param->set_id((*global_id)++);
|
||||
*new_param->mutable_name() =
|
||||
GetFullName(inst.name(), '.', new_param->id());
|
||||
if (rewrite_into_pred) {
|
||||
*new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
|
||||
}
|
||||
operands_id.push_back(new_param->id());
|
||||
}
|
||||
if (inst.id() == original_reducer->root_id()) {
|
||||
HloInstructionProto* new_root = reducer.add_instructions();
|
||||
*new_root = inst;
|
||||
*new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
|
||||
*new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
|
||||
new_root->set_id((*global_id)++);
|
||||
*new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
|
||||
if (rewrite_into_pred) {
|
||||
*new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
|
||||
*new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
|
||||
}
|
||||
new_root->clear_operand_ids();
|
||||
for (int64 operand_id : operands_id) {
|
||||
new_root->add_operand_ids(operand_id);
|
||||
}
|
||||
reducer.set_root_id(inst.id());
|
||||
reducer.set_root_id(new_root->id());
|
||||
}
|
||||
}
|
||||
return reducer;
|
||||
@ -3323,7 +3335,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
*program_shape->mutable_result() =
|
||||
ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
|
||||
|
||||
std::vector<HloComputationProto> called_computatons;
|
||||
std::vector<HloComputationProto> called_computations;
|
||||
auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
|
||||
int64 operand_index) -> StatusOr<bool> {
|
||||
int64 operand_id = instr_proto->operand_ids(operand_index);
|
||||
@ -3336,7 +3348,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
// graph with have id set to `id`.
|
||||
auto process_instruction = [&](const HloInstructionProto* instr_proto,
|
||||
bool need_rewrite, int64 id,
|
||||
absl::Span<int64 const> operand_ids) {
|
||||
absl::Span<int64 const> operand_ids,
|
||||
int64* global_id) {
|
||||
// Rewrite the instruction with following rules:
|
||||
// - Unary ops: Convert into bitcast (identity) with type Pred.
|
||||
// - Binary ops: Convert into binary or.
|
||||
@ -3364,6 +3377,17 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
if (!need_rewrite) {
|
||||
*new_instr->mutable_name() =
|
||||
GetFullName(instr_proto->opcode(), kNameSeparator, id);
|
||||
if (opcode == HloOpcode::kReduce) {
|
||||
// Copy the reducer to the new module, with a new id that's same as the
|
||||
// reduce op.
|
||||
HloComputationProto* reducer =
|
||||
&embedded_[new_instr->called_computation_ids(0)];
|
||||
int64 reducer_id = (*global_id)++;
|
||||
new_instr->clear_called_computation_ids();
|
||||
new_instr->add_called_computation_ids(reducer_id);
|
||||
called_computations.push_back(CopyReducer(
|
||||
reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
*new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
|
||||
@ -3439,9 +3463,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReduce: {
|
||||
int64 reducer_id = new_instr->called_computation_ids(0);
|
||||
called_computatons.push_back(
|
||||
CreateReduceOr(reducer_id, &embedded_[reducer_id]));
|
||||
auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
|
||||
int64 reducer_id = (*global_id)++;
|
||||
new_instr->clear_called_computation_ids();
|
||||
new_instr->add_called_computation_ids(reducer_id);
|
||||
called_computations.push_back(CopyReducer(
|
||||
reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kTuple:
|
||||
@ -3567,10 +3594,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
if (next_operand >= instr_proto->operand_ids_size() ||
|
||||
!should_visit_operand || InstrIsSetBound(instr_proto)) {
|
||||
// No more operands to process, process self.
|
||||
int64 new_id = ++global_id;
|
||||
int64 new_id = global_id++;
|
||||
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
|
||||
TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
|
||||
new_id, item.processed_operands));
|
||||
new_id, item.processed_operands,
|
||||
&global_id));
|
||||
stacktop_id = new_id;
|
||||
seen[item_key] = stacktop_id;
|
||||
worklist.pop_back();
|
||||
@ -3602,10 +3630,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
module->set_entry_computation_name(entry.name());
|
||||
module->set_entry_computation_id(entry.id());
|
||||
*module->mutable_host_program_shape() = *program_shape;
|
||||
for (auto& called_comp : called_computatons) {
|
||||
for (auto& called_comp : called_computations) {
|
||||
*module->add_computations() = called_comp;
|
||||
}
|
||||
*module->add_computations() = std::move(entry);
|
||||
// Make sure all ids appear in the computation with ascending order.
|
||||
absl::c_sort(*module->mutable_computations(),
|
||||
[](const HloComputationProto& c1,
|
||||
const HloComputationProto& c2) { return c1.id() < c2.id(); });
|
||||
XLA_VLOG_LINES(3, module->DebugString());
|
||||
return std::move(computation);
|
||||
}
|
||||
|
@ -2098,7 +2098,10 @@ xla_test(
|
||||
name = "dynamism_inference_test",
|
||||
srcs = ["dynamism_inference_test.cc"],
|
||||
deps = [
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -2109,10 +2112,8 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:prng",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/prng.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -162,6 +163,21 @@ TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
XlaBuilder b(TestName());
|
||||
auto c = ConstantR0<int32>(&b, 42);
|
||||
auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
|
||||
auto zero = ConstantR0<int32>(&b, 0);
|
||||
XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
|
||||
auto reduce = Reduce(p, zero, add_s32, {0});
|
||||
auto pred = Eq(c, reduce);
|
||||
auto result = Select(pred, reduce, c);
|
||||
EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
|
Loading…
x
Reference in New Issue
Block a user