[XLA] Add ReducePrecisionInsertion pass.
This new HLO pass, intended for experimental purposes rather than optimization, inserts ReducePrecision instructions (with user-specified bitsizes) after all instructions of opcode types specified by the user. This makes it possible to do experiments on the numerical effects storing intermediate values in reduced precision without changing the HLO graph definition. PiperOrigin-RevId: 161117760
This commit is contained in:
parent
27b15d8b59
commit
a22dad9836
@ -1887,6 +1887,38 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "reduce_precision_insertion",
|
||||
srcs = ["reduce_precision_insertion.cc"],
|
||||
hdrs = ["reduce_precision_insertion.h"],
|
||||
deps = [
|
||||
":buffer_liveness",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "reduce_precision_insertion_test",
|
||||
size = "small",
|
||||
srcs = ["reduce_precision_insertion_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":reduce_precision_insertion",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
|
@ -95,6 +95,7 @@ HLO_MATCHER(Parameter);
|
||||
HLO_MATCHER(Power);
|
||||
HLO_MATCHER(Recv);
|
||||
HLO_MATCHER(Reduce);
|
||||
HLO_MATCHER(ReducePrecision);
|
||||
HLO_MATCHER(ReduceWindow);
|
||||
HLO_MATCHER(Remainder);
|
||||
HLO_MATCHER(Reshape);
|
||||
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
|
||||
|
||||
for (auto& computation : module->computations()) {
|
||||
std::vector<HloInstruction*> instructions_to_suffix;
|
||||
|
||||
for (auto& instruction : computation->instructions()) {
|
||||
VLOG(3) << "Visited instruction: " << instruction->ToString();
|
||||
|
||||
// For now, ReducePrecision is only implemented for F32 data, so this
|
||||
// ignore instructions that produce other data. In particular, this
|
||||
// currently ignores instructions producing tuples, even if those tuples
|
||||
// contain F32 data inside them. The assumption is that in most cases
|
||||
// equivalent behavior can be obtained by adding ReducePrecision
|
||||
// instructions after the instructions that pull the F32 data out of the
|
||||
// tuples.
|
||||
if (instruction->shape().element_type() == PrimitiveType::F32 &&
|
||||
should_reduce_output_precision_(instruction->opcode())) {
|
||||
instructions_to_suffix.push_back(instruction.get());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& instruction : instructions_to_suffix) {
|
||||
HloInstruction* reduced =
|
||||
computation->AddInstruction(HloInstruction::CreateReducePrecision(
|
||||
instruction->shape(), instruction, exponent_bits_,
|
||||
mantissa_bits_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->ReplaceUsesOfInstruction(instruction, reduced));
|
||||
VLOG(2) << "Inserted new op after instruction: "
|
||||
<< instruction->ToString();
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace xla
|
67
tensorflow/compiler/xla/service/reduce_precision_insertion.h
Normal file
67
tensorflow/compiler/xla/service/reduce_precision_insertion.h
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
|
||||
// purposes of experimenting with the effects of reduced-precision storage of
|
||||
// intermediate values.
|
||||
class ReducePrecisionInsertion : public HloPassInterface {
|
||||
using OpcodeFilterFunction = std::function<bool(HloOpcode)>;
|
||||
|
||||
public:
|
||||
// The exponent_bits and mantissa_bits arguments specify the parameters of
|
||||
// the instructions to insert. The instructions will be inserted after each
|
||||
// instruction with an opcode for which the should_reduce_output_precision
|
||||
// function returns true and the output type is F32.
|
||||
explicit ReducePrecisionInsertion(
|
||||
const int exponent_bits, const int mantissa_bits,
|
||||
const OpcodeFilterFunction& should_reduce_output_precision)
|
||||
: exponent_bits_(exponent_bits),
|
||||
mantissa_bits_(mantissa_bits),
|
||||
should_reduce_output_precision_(should_reduce_output_precision) {}
|
||||
~ReducePrecisionInsertion() override{};
|
||||
|
||||
tensorflow::StringPiece name() const override {
|
||||
return "reduce-precision-insertion";
|
||||
}
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (reduce-precision instructions were inserted).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Parameters for the precision reduction to be added.
|
||||
const int exponent_bits_;
|
||||
const int mantissa_bits_;
|
||||
|
||||
// Function to determine (from the opcode) whether a given instruction should
|
||||
// have a reduce-precision instruction inserted in its output stream.
|
||||
const OpcodeFilterFunction& should_reduce_output_precision_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
|
@ -0,0 +1,186 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class ReducePrecisionInsertionTest : public HloTestBase {
|
||||
protected:
|
||||
bool InsertOps(HloModule* module,
|
||||
const std::function<bool(HloOpcode)>& filter) {
|
||||
ReducePrecisionInsertion op_insertion(5, 10, filter);
|
||||
StatusOr<bool> result = op_insertion.Run(module);
|
||||
EXPECT_IS_OK(result.status());
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, RootInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(),
|
||||
[](HloOpcode h) { return h == HloOpcode::kCos; }));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
|
||||
EXPECT_EQ(computation->root_instruction()->operand(0), b);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a graph with two parameters feeding into unary cosine functions,
|
||||
// and the output of those feeds into an add function. Feeding the outputs
|
||||
// from the suffixed cosine functions into a binary add function allows us to
|
||||
// confirm that the separate operand streams are not crossed when the new
|
||||
// instructions are inserted.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* a_cos = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
HloInstruction* b =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
|
||||
HloInstruction* b_cos = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));
|
||||
|
||||
HloInstruction* c = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_EQ(c->operand(0), a_cos);
|
||||
EXPECT_EQ(c->operand(1), b_cos);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(),
|
||||
[](HloOpcode h) { return h == HloOpcode::kCos; }));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(c->operand(0), op::ReducePrecision());
|
||||
EXPECT_EQ(c->operand(0)->operand(0), a_cos);
|
||||
EXPECT_THAT(c->operand(1), op::ReducePrecision());
|
||||
EXPECT_EQ(c->operand(1)->operand(0), b_cos);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(S32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since none of the instructions produce F32 data, this should not change
|
||||
// the graph.
|
||||
EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode) { return true; }));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since none of the instructions match the should_reduce_output_precision
|
||||
// function, this should not change the graph.
|
||||
EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode h) { return false; }));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateReducePrecision(shape, a, 9, 23));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
|
||||
// This should insert a new ReducePrecision after the existing one, but
|
||||
// should not then recurse by adding another after the just-inserted one.
|
||||
EXPECT_TRUE(InsertOps(module.get(), [](HloOpcode h) {
|
||||
return h == HloOpcode::kReducePrecision;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
|
||||
EXPECT_EQ(computation->root_instruction()->operand(0), b);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
|
||||
}
|
Loading…
Reference in New Issue
Block a user