diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index c23b40ab6cd..821710ed2a4 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -94,9 +94,12 @@ void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
   *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
-// Converts a HloComputation into ReducerOr with predicate types.
-HloComputationProto CreateReduceOr(int64 reducer_id,
-                                   HloComputationProto* original_reducer) {
+// Copy `original_reducer` into a new computation proto with `reducer_id` as new
+// id. If `rewrite_into_pred` is true, the instructions in the reducer are
+// rewritten into predicate form.
+HloComputationProto CopyReducer(int64 reducer_id,
+                                HloComputationProto* original_reducer,
+                                bool rewrite_into_pred, int64* global_id) {
   HloComputationProto reducer;
   SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
   std::vector<int64> operands_id;
@@ -106,19 +109,28 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
         HloOpcode::kParameter) {
       HloInstructionProto* new_param = reducer.add_instructions();
       *new_param = inst;
-      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      operands_id.push_back(inst.id());
+      new_param->set_id((*global_id)++);
+      *new_param->mutable_name() =
+          GetFullName(inst.name(), '.', new_param->id());
+      if (rewrite_into_pred) {
+        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      }
+      operands_id.push_back(new_param->id());
     }
     if (inst.id() == original_reducer->root_id()) {
       HloInstructionProto* new_root = reducer.add_instructions();
       *new_root = inst;
-      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      new_root->set_id((*global_id)++);
+      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
+      if (rewrite_into_pred) {
+        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      }
       new_root->clear_operand_ids();
       for (int64 operand_id : operands_id) {
         new_root->add_operand_ids(operand_id);
       }
-      reducer.set_root_id(inst.id());
+      reducer.set_root_id(new_root->id());
     }
   }
   return reducer;
@@ -3323,7 +3335,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::vector<HloComputationProto> called_computatons;
+  std::vector<HloComputationProto> called_computations;
   auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                  int64 operand_index) -> StatusOr<bool> {
     int64 operand_id = instr_proto->operand_ids(operand_index);
@@ -3336,7 +3348,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   // graph with have id set to `id`.
   auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                  bool need_rewrite, int64 id,
-                                 absl::Span<int64 const> operand_ids) {
+                                 absl::Span<int64 const> operand_ids,
+                                 int64* global_id) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3364,6 +3377,17 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     if (!need_rewrite) {
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      if (opcode == HloOpcode::kReduce) {
+        // Copy the reducer to the new module, with a new id that's same as the
+        // reduce op.
+        HloComputationProto* reducer =
+            &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
+      }
       return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
@@ -3439,9 +3463,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         break;
       }
       case HloOpcode::kReduce: {
-        int64 reducer_id = new_instr->called_computation_ids(0);
-        called_computatons.push_back(
-            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
+        int64 reducer_id = (*global_id)++;
+        new_instr->clear_called_computation_ids();
+        new_instr->add_called_computation_ids(reducer_id);
+        called_computations.push_back(CopyReducer(
+            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
         break;
       }
       case HloOpcode::kTuple:
@@ -3567,10 +3594,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
       if (next_operand >= instr_proto->operand_ids_size() ||
           !should_visit_operand || InstrIsSetBound(instr_proto)) {
         // No more operands to process, process self.
-        int64 new_id = ++global_id;
+        int64 new_id = global_id++;
         VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
         TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
-                                               new_id, item.processed_operands));
+                                               new_id, item.processed_operands,
+                                               &global_id));
         stacktop_id = new_id;
         seen[item_key] = stacktop_id;
         worklist.pop_back();
@@ -3602,10 +3630,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   module->set_entry_computation_name(entry.name());
   module->set_entry_computation_id(entry.id());
   *module->mutable_host_program_shape() = *program_shape;
-  for (auto& called_comp : called_computatons) {
+  for (auto& called_comp : called_computations) {
     *module->add_computations() = called_comp;
   }
   *module->add_computations() = std::move(entry);
+  // Make sure all ids appear in the computation with ascending order.
+  absl::c_sort(*module->mutable_computations(),
+               [](const HloComputationProto& c1,
+                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
   XLA_VLOG_LINES(3, module->DebugString());
   return std::move(computation);
 }
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7080e4460bf..9923a6494c4 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2098,7 +2098,10 @@ xla_test(
     name = "dynamism_inference_test",
    srcs = ["dynamism_inference_test.cc"],
     deps = [
+        ":literal_test_util",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -2109,10 +2112,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:prng",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 1763ed6090e..892fdb86362 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -162,6 +163,21 @@ TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
   }
 }
 
+TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
+    auto zero = ConstantR0<int32>(&b, 0);
+    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
+    auto reduce = Reduce(p, zero, add_s32, {0});
+    auto pred = Eq(c, reduce);
+    auto result = Select(pred, reduce, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);