[XLA] Reduce the amount of compute performed on the operand shape in
BatchNormInference. PiperOrigin-RevId: 271172669
This commit is contained in:
parent
2eb6dc0f2e
commit
7484cb2b12
@ -310,7 +310,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
|
||||
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
|
||||
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
operand_shape,
|
||||
feature_shape,
|
||||
computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(epsilon_literal))),
|
||||
{}));
|
||||
@ -334,42 +334,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
|
||||
HloInstruction* a, HloInstruction* b) {
|
||||
return add(HloInstruction::CreateBinary(shape, opcode, a, b));
|
||||
};
|
||||
auto feature_broadcast = [&](HloInstruction* a) {
|
||||
return add(
|
||||
HloInstruction::CreateBroadcast(operand_shape, a, {feature_index}));
|
||||
};
|
||||
|
||||
int64 instruction_count_before = computation_->instruction_count();
|
||||
auto true_scale = add_binary(
|
||||
feature_shape, HloOpcode::kMultiply, scale,
|
||||
add(Rsqrt(add_binary(feature_shape, HloOpcode::kAdd, var, epsilon),
|
||||
add)));
|
||||
auto true_shift = add_binary(
|
||||
feature_shape, HloOpcode::kSubtract, offset,
|
||||
add_binary(feature_shape, HloOpcode::kMultiply, mean, true_scale));
|
||||
|
||||
auto scale_broadcasted = add(
|
||||
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
|
||||
|
||||
auto offset_broadcasted = add(
|
||||
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
|
||||
|
||||
auto mean_broadcasted = add(
|
||||
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
|
||||
|
||||
auto var_broadcasted =
|
||||
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
|
||||
|
||||
// Var[X] + epsilon.
|
||||
auto var_add_epsilon =
|
||||
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
|
||||
|
||||
// 1 / Sqrt[Var[X] + epsilon].
|
||||
auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
|
||||
|
||||
// X - E[X].
|
||||
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
|
||||
operand, mean_broadcasted);
|
||||
|
||||
// (X - E[X]) / Sqrt[Var[X] + epsilon].
|
||||
auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
|
||||
operand_minus_mean, rsqrt_var_add_epsilon);
|
||||
|
||||
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
|
||||
auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
|
||||
normalized, scale_broadcasted);
|
||||
|
||||
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
|
||||
auto shifted_normalized = HloInstruction::CreateBinary(
|
||||
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
|
||||
auto shifted_normalized =
|
||||
add_binary(operand_shape, HloOpcode::kAdd,
|
||||
add_binary(operand_shape, HloOpcode::kMultiply, operand,
|
||||
feature_broadcast(true_scale)),
|
||||
feature_broadcast(true_shift));
|
||||
|
||||
int64 instruction_count_after = computation_->instruction_count();
|
||||
CHECK_EQ(instruction_count_after,
|
||||
@ -390,8 +373,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
|
||||
}
|
||||
shifted_normalized->set_sharding(sharding);
|
||||
}
|
||||
TF_CHECK_OK(
|
||||
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
|
||||
TF_CHECK_OK(ReplaceInstruction(batch_norm, shifted_normalized));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user