[XLA] Remove non-scalar requirement on reduce precision

This was fixed in cl/211936337

PiperOrigin-RevId: 221249970
This commit is contained in:
Benjamin Kramer 2018-11-13 05:09:19 -08:00 committed by TensorFlower Gardener
parent d8af932d02
commit b9247ca08f
3 changed files with 30 additions and 9 deletions

View File

@ -111,7 +111,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
VLOG(2) << "Adding to operand " << i << ": " << operand;
if (!is_valid_shape(operand->shape())) {
VLOG(2) << "Skipped: value is not an F32 vector";
VLOG(2) << "Skipped: value is not of type F32";
continue;
}
@ -168,7 +168,7 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
<< instruction->ToString();
if (!is_valid_shape(instruction->shape())) {
VLOG(2) << "Skipped: value is not an F32 nonscalar array";
VLOG(2) << "Skipped: value is not of type F32";
continue;
}

View File

@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass {
// equivalent behavior can be obtained by adding ReducePrecision
// instructions after the instructions that pull the F32 arrays out of
// the tuples.
//
// TODO(b/64093391): Remove the IsScalar check once this won't cause
// failures on the GPU backend if the ReducePrecision instruction ends up
// inserted between a scalar constant and the init_value argument of a
// Reduce operation.
return shape.element_type() == PrimitiveType::F32 &&
!ShapeUtil::IsScalar(shape);
return shape.element_type() == PrimitiveType::F32;
}
// Is this instruction one such that following or preceding it with a new

View File

@ -71,6 +71,33 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
}
// Regression test for the removal of the non-scalar requirement in
// is_valid_shape: a ReducePrecision op must now also be inserted before an
// instruction whose F32 operand is a scalar (rank-0 shape), not just a
// vector/array. Mirrors BeforeUnaryInstruction but with shape {} instead
// of {4}.
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
auto builder = HloComputation::Builder(TestName());
// Rank-0 (scalar) F32 shape -- the case the old IsScalar check rejected.
Shape shape = ShapeUtil::MakeShape(F32, {});
// Create a simple graph with a parameter feeding a unary cosine function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_EQ(b->operand(0), a);
// Insert ReducePrecision on the inputs of kCos; returns true iff the pass
// changed the module, which it now must even for a scalar operand.
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm expected graph after adding ops: a ReducePrecision node is now
// interposed between the parameter and the cosine.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
}
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});