From b9247ca08f491d02208732e51d180bb3304d977b Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 13 Nov 2018 05:09:19 -0800 Subject: [PATCH] [XLA] Remove non-scalar requirement on reduce precision This was fixed in cl/211936337 PiperOrigin-RevId: 221249970 --- .../xla/service/reduce_precision_insertion.cc | 4 +-- .../xla/service/reduce_precision_insertion.h | 8 +----- .../reduce_precision_insertion_test.cc | 27 +++++++++++++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 688cceff0cd..b70cb705747 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -111,7 +111,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( VLOG(2) << "Adding to operand " << i << ": " << operand; if (!is_valid_shape(operand->shape())) { - VLOG(2) << "Skipped: value is not an F32 vector"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } @@ -168,7 +168,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( << instruction->ToString(); if (!is_valid_shape(instruction->shape())) { - VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 0b4e82e8d60..76c6a87f176 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass { // equivalent behavior can be obtained by adding ReducePrecision // instructions after the instructions that pull the F32 arrays out of // the tuples. - // - // TODO(b/64093391): Remove the IsScalar check once this won't cause - // failures on the GPU backend if the ReducePrecision instruction ends up - // inserted between a scalar constant and the init_value argument of a - // Reduce operation. - return shape.element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(shape); + return shape.element_type() == PrimitiveType::F32; } // Is this instruction one such that following or preceding it with a new diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index e4e4757c501..16fa80d53e7 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -71,6 +71,33 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); } +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewUnverifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); +} + TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4});