[XLA] Remove non-scalar requirement on reduce precision
This was fixed in cl/211936337 PiperOrigin-RevId: 221249970
This commit is contained in:
parent
d8af932d02
commit
b9247ca08f
@ -111,7 +111,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
|
|||||||
VLOG(2) << "Adding to operand " << i << ": " << operand;
|
VLOG(2) << "Adding to operand " << i << ": " << operand;
|
||||||
|
|
||||||
if (!is_valid_shape(operand->shape())) {
|
if (!is_valid_shape(operand->shape())) {
|
||||||
VLOG(2) << "Skipped: value is not an F32 vector";
|
VLOG(2) << "Skipped: value is not of type F32";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,7 +168,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
|
|||||||
<< instruction->ToString();
|
<< instruction->ToString();
|
||||||
|
|
||||||
if (!is_valid_shape(instruction->shape())) {
|
if (!is_valid_shape(instruction->shape())) {
|
||||||
VLOG(2) << "Skipped: value is not an F32 nonscalar array";
|
VLOG(2) << "Skipped: value is not of type F32";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass {
|
|||||||
// equivalent behavior can be obtained by adding ReducePrecision
|
// equivalent behavior can be obtained by adding ReducePrecision
|
||||||
// instructions after the instructions that pull the F32 arrays out of
|
// instructions after the instructions that pull the F32 arrays out of
|
||||||
// the tuples.
|
// the tuples.
|
||||||
//
|
return shape.element_type() == PrimitiveType::F32;
|
||||||
// TODO(b/64093391): Remove the IsScalar check once this won't cause
|
|
||||||
// failures on the GPU backend if the ReducePrecision instruction ends up
|
|
||||||
// inserted between a scalar constant and the init_value argument of a
|
|
||||||
// Reduce operation.
|
|
||||||
return shape.element_type() == PrimitiveType::F32 &&
|
|
||||||
!ShapeUtil::IsScalar(shape);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is this instruction one such that following or preceding it with a new
|
// Is this instruction one such that following or preceding it with a new
|
||||||
|
@ -71,6 +71,33 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
|
|||||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||||
|
|
||||||
|
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||||
|
HloInstruction* a =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||||
|
HloInstruction* b = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||||
|
|
||||||
|
auto module = CreateNewUnverifiedModule();
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
// Confirm expected state before adding ops.
|
||||||
|
EXPECT_EQ(computation->root_instruction(), b);
|
||||||
|
EXPECT_EQ(b->operand(0), a);
|
||||||
|
|
||||||
|
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||||
|
[](const HloInstruction* instruction) {
|
||||||
|
return instruction->opcode() == HloOpcode::kCos;
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Confirm expected graph after adding ops.
|
||||||
|
EXPECT_EQ(computation->root_instruction(), b);
|
||||||
|
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
|
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user