diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9cf8b83d5b9..d08bc61e37d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1887,6 +1887,38 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "reduce_precision_insertion",
+    srcs = ["reduce_precision_insertion.cc"],
+    hdrs = ["reduce_precision_insertion.h"],
+    deps = [
+        ":buffer_liveness",
+        ":hlo",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_test(
+    name = "reduce_precision_insertion_test",
+    size = "small",
+    srcs = ["reduce_precision_insertion_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_matchers",
+        ":reduce_precision_insertion",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 141251011cc..79f17bbb6bd 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -95,6 +95,7 @@ HLO_MATCHER(Parameter);
 HLO_MATCHER(Power);
 HLO_MATCHER(Recv);
 HLO_MATCHER(Reduce);
+HLO_MATCHER(ReducePrecision);
 HLO_MATCHER(ReduceWindow);
 HLO_MATCHER(Remainder);
 HLO_MATCHER(Reshape);
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
new file mode 100644
index 00000000000..dafefdc4910
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
+  bool changed = false;
+  VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
+
+  for (auto& computation : module->computations()) {
+    std::vector<HloInstruction*> instructions_to_suffix;
+
+    for (auto& instruction : computation->instructions()) {
+      VLOG(3) << "Visited instruction: " << instruction->ToString();
+
+      // For now, ReducePrecision is only implemented for F32 data, so this
+      // ignores instructions that produce other data.  In particular, this
+      // currently ignores instructions producing tuples, even if those tuples
+      // contain F32 data inside them.  The assumption is that in most cases
+      // equivalent behavior can be obtained by adding ReducePrecision
+      // instructions after the instructions that pull the F32 data out of the
+      // tuples.
+      if (instruction->shape().element_type() == PrimitiveType::F32 &&
+          should_reduce_output_precision_(instruction->opcode())) {
+        instructions_to_suffix.push_back(instruction.get());
+      }
+    }
+
+    for (auto& instruction : instructions_to_suffix) {
+      HloInstruction* reduced =
+          computation->AddInstruction(HloInstruction::CreateReducePrecision(
+              instruction->shape(), instruction, exponent_bits_,
+              mantissa_bits_));
+      TF_RETURN_IF_ERROR(
+          computation->ReplaceUsesOfInstruction(instruction, reduced));
+      VLOG(2) << "Inserted new op after instruction: "
+              << instruction->ToString();
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla
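Usage sketch (not part of the patch): a pass like this would be wired into a backend's HloPassPipeline, which forwards the AddPass arguments to the ReducePrecisionInsertion constructor. The pipeline name and the kDot opcode filter below are made-up examples; where (and whether) a given backend registers this pass is outside this change. Note that exponent_bits=5 / mantissa_bits=10 matches the bit layout of IEEE half precision:

    // Illustrative wiring only: reduce the precision of every F32
    // dot-product output to half-precision-like storage.
    HloPassPipeline pipeline("reduce-precision-experiment");
    pipeline.AddPass<ReducePrecisionInsertion>(
        /*exponent_bits=*/5, /*mantissa_bits=*/10,
        [](HloOpcode opcode) { return opcode == HloOpcode::kDot; });
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());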
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
new file mode 100644
index 00000000000..e9c8bba0313
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+
+// HLO pass which inserts reduce-precision instructions into the HLO graph, for
+// purposes of experimenting with the effects of reduced-precision storage of
+// intermediate values.
+class ReducePrecisionInsertion : public HloPassInterface {
+  using OpcodeFilterFunction = std::function<bool(HloOpcode)>;
+
+ public:
+  // The exponent_bits and mantissa_bits arguments specify the parameters of
+  // the instructions to insert.  The instructions will be inserted after each
+  // instruction with an opcode for which the should_reduce_output_precision
+  // function returns true and the output type is F32.
+  explicit ReducePrecisionInsertion(
+      const int exponent_bits, const int mantissa_bits,
+      const OpcodeFilterFunction& should_reduce_output_precision)
+      : exponent_bits_(exponent_bits),
+        mantissa_bits_(mantissa_bits),
+        should_reduce_output_precision_(should_reduce_output_precision) {}
+  ~ReducePrecisionInsertion() override {}
+
+  tensorflow::StringPiece name() const override {
+    return "reduce-precision-insertion";
+  }
+
+  // Run the pass on the given module.  Returns whether the module was changed
+  // (reduce-precision instructions were inserted).
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  // Parameters for the precision reduction to be added.
+  const int exponent_bits_;
+  const int mantissa_bits_;
+
+  // Function to determine (from the opcode) whether a given instruction should
+  // have a reduce-precision instruction inserted in its output stream.
+  const OpcodeFilterFunction& should_reduce_output_precision_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
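As a mental model for what the inserted instruction computes: ReducePrecision keeps the data in F32, but makes each value behave as if it had been round-tripped through a narrower float with the given exponent and mantissa widths. The standalone sketch below shows only the mantissa-truncation half of that; the real op's semantics also include round-to-nearest and exponent-range clamping, which this deliberately omits:

    #include <cstdint>
    #include <cstring>

    // Zero out the low (23 - mantissa_bits) bits of an F32 mantissa, as if
    // the value had been stored with only mantissa_bits of fraction.
    // Illustration only: the actual ReducePrecision HLO also rounds and
    // clamps the exponent range; this sketch does neither.
    float TruncateMantissa(float value, int mantissa_bits) {
      uint32_t bits;
      std::memcpy(&bits, &value, sizeof(bits));
      bits &= ~((uint32_t{1} << (23 - mantissa_bits)) - 1u);
      std::memcpy(&value, &bits, sizeof(value));
      return value;
    }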
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc
new file mode 100644
index 00000000000..80717ec2e3f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+
+using ::testing::UnorderedElementsAre;
+
+class ReducePrecisionInsertionTest : public HloTestBase {
+ protected:
+  bool InsertOps(HloModule* module,
+                 const std::function<bool(HloOpcode)>& filter) {
+    ReducePrecisionInsertion op_insertion(5, 10, filter);
+    StatusOr<bool> result = op_insertion.Run(module);
+    EXPECT_IS_OK(result.status());
+    return result.ValueOrDie();
+  }
+};
+
+TEST_F(ReducePrecisionInsertionTest, RootInstruction) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4});
+
+  // Create a simple graph with a parameter feeding a unary cosine function.
+  HloInstruction* a =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  // Confirm expected state before adding ops.
+  EXPECT_EQ(computation->root_instruction(), b);
+
+  EXPECT_TRUE(InsertOps(module.get(),
+                        [](HloOpcode h) { return h == HloOpcode::kCos; }));
+
+  // Confirm expected graph after adding ops.
+  EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
+  EXPECT_EQ(computation->root_instruction()->operand(0), b);
+}
+
+TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4});
+
+  // Create a graph with two parameters feeding into unary cosine functions,
+  // with the outputs of those feeding into a binary add function.  Routing
+  // both suffixed cosine outputs through a single add lets us confirm that
+  // the separate operand streams are not crossed when the new instructions
+  // are inserted.
+  HloInstruction* a =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+  HloInstruction* a_cos = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
+
+  HloInstruction* b =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+  HloInstruction* b_cos = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));
+
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));
+
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  // Confirm expected graph before adding ops.
+  EXPECT_EQ(c->operand(0), a_cos);
+  EXPECT_EQ(c->operand(1), b_cos);
+
+  EXPECT_TRUE(InsertOps(module.get(),
+                        [](HloOpcode h) { return h == HloOpcode::kCos; }));
+
+  // Confirm expected graph after adding ops.
+  EXPECT_THAT(c->operand(0), op::ReducePrecision());
+  EXPECT_EQ(c->operand(0)->operand(0), a_cos);
+  EXPECT_THAT(c->operand(1), op::ReducePrecision());
+  EXPECT_EQ(c->operand(1)->operand(0), b_cos);
+}
+
+TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(S32, {4});
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+  HloInstruction* y = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  // Confirm expected graph before adding ops.
+  EXPECT_THAT(x->users(), UnorderedElementsAre(y));
+  EXPECT_EQ(computation->root_instruction(), y);
+
+  // Since none of the instructions produce F32 data, this should not change
+  // the graph.
+  EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode) { return true; }));
+
+  // Confirm that graph has not changed.
+  EXPECT_THAT(x->users(), UnorderedElementsAre(y));
+  EXPECT_EQ(computation->root_instruction(), y);
+}
+
+TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4});
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+  HloInstruction* y = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  // Confirm expected graph before adding ops.
+  EXPECT_THAT(x->users(), UnorderedElementsAre(y));
+  EXPECT_EQ(computation->root_instruction(), y);
+
+  // Since none of the instructions match the should_reduce_output_precision
+  // function, this should not change the graph.
+  EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode) { return false; }));
+
+  // Confirm that graph has not changed.
+  EXPECT_THAT(x->users(), UnorderedElementsAre(y));
+  EXPECT_EQ(computation->root_instruction(), y);
+}
+
+TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4});
+  HloInstruction* a =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateReducePrecision(shape, a, 9, 23));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  // Confirm expected state before adding ops.
+  EXPECT_EQ(computation->root_instruction(), b);
+
+  // This should insert a new ReducePrecision after the existing one, but
+  // should not then recurse by adding another after the just-inserted one.
+  EXPECT_TRUE(InsertOps(module.get(), [](HloOpcode h) {
+    return h == HloOpcode::kReducePrecision;
+  }));
+
+  // Confirm expected graph after adding ops.
+  EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
+  EXPECT_EQ(computation->root_instruction()->operand(0), b);
+}
+
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
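A note on the new matcher: op::ReducePrecision composes with operand matchers like every other HLO_MATCHER-generated matcher, so the post-insertion structure checks in these tests could also be written as a single nested expectation. Sketch only, assuming an op::Cos matcher is available (if hlo_matchers.h does not already define one, it would need its own HLO_MATCHER(Cos) entry):

    // Nested-matcher form of the RootInstruction checks above: the root
    // should be a ReducePrecision op fed by the cosine of the parameter.
    EXPECT_THAT(computation->root_instruction(),
                op::ReducePrecision(op::Cos(op::Parameter())));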