[XLA] Reduce the amount of compute performed on the operand shape in

BatchNormInference.

PiperOrigin-RevId: 271172669
This commit is contained in:
Blake Hechtman 2019-09-25 11:51:48 -07:00 committed by TensorFlower Gardener
parent 2eb6dc0f2e
commit 7484cb2b12

View File

@@ -310,7 +310,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
feature_shape,
computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(epsilon_literal))),
{}));
@@ -334,42 +334,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
HloInstruction* a, HloInstruction* b) {
return add(HloInstruction::CreateBinary(shape, opcode, a, b));
};
auto feature_broadcast = [&](HloInstruction* a) {
return add(
HloInstruction::CreateBroadcast(operand_shape, a, {feature_index}));
};
int64 instruction_count_before = computation_->instruction_count();
auto true_scale = add_binary(
feature_shape, HloOpcode::kMultiply, scale,
add(Rsqrt(add_binary(feature_shape, HloOpcode::kAdd, var, epsilon),
add)));
auto true_shift = add_binary(
feature_shape, HloOpcode::kSubtract, offset,
add_binary(feature_shape, HloOpcode::kMultiply, mean, true_scale));
auto scale_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
auto offset_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
auto var_broadcasted =
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
auto var_add_epsilon =
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
operand, mean_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon].
auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
normalized, scale_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
auto shifted_normalized =
add_binary(operand_shape, HloOpcode::kAdd,
add_binary(operand_shape, HloOpcode::kMultiply, operand,
feature_broadcast(true_scale)),
feature_broadcast(true_shift));
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
@@ -390,8 +373,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
}
shifted_normalized->set_sharding(sharding);
}
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
TF_CHECK_OK(ReplaceInstruction(batch_norm, shifted_normalized));
return Status::OK();
}