[XLA] Remove non-scalar requirement on reduce precision
This was fixed in cl/211936337 PiperOrigin-RevId: 221249970
This commit is contained in:
parent
d8af932d02
commit
b9247ca08f
@ -111,7 +111,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
|
||||
VLOG(2) << "Adding to operand " << i << ": " << operand;
|
||||
|
||||
if (!is_valid_shape(operand->shape())) {
|
||||
VLOG(2) << "Skipped: value is not an F32 vector";
|
||||
VLOG(2) << "Skipped: value is not of type F32";
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -168,7 +168,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
|
||||
<< instruction->ToString();
|
||||
|
||||
if (!is_valid_shape(instruction->shape())) {
|
||||
VLOG(2) << "Skipped: value is not an F32 nonscalar array";
|
||||
VLOG(2) << "Skipped: value is not of type F32";
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass {
|
||||
// equivalent behavior can be obtained by adding ReducePrecision
|
||||
// instructions after the instructions that pull the F32 arrays out of
|
||||
// the tuples.
|
||||
//
|
||||
// TODO(b/64093391): Remove the IsScalar check once this won't cause
|
||||
// failures on the GPU backend if the ReducePrecision instruction ends up
|
||||
// inserted between a scalar constant and the init_value argument of a
|
||||
// Reduce operation.
|
||||
return shape.element_type() == PrimitiveType::F32 &&
|
||||
!ShapeUtil::IsScalar(shape);
|
||||
return shape.element_type() == PrimitiveType::F32;
|
||||
}
|
||||
|
||||
// Is this instruction one such that following or preceding it with a new
|
||||
|
@ -71,6 +71,33 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
|
||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
Loading…
x
Reference in New Issue
Block a user