Remove ReducePrecisionInsertion pass.

It doesn't seem to be useful anymore.
Also remove the related XLA flags which could be used to enable it.

PiperOrigin-RevId: 278331729
Change-Id: I3c0094b60f4b51ee3b64ec35bca96ec17f69c5f6
Adrian Kuegel 2019-11-04 02:05:05 -08:00 committed by TensorFlower Gardener
parent 32d76ec3e6
commit f211533f4d
17 changed files with 7 additions and 1458 deletions
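For reference, the surface removed here was driven entirely through DebugOptions: callers appended HloReducePrecisionOptions entries (typically built with ReducePrecisionInsertion::make_options_proto, as the removed accuracy tests below do), and the CPU, GPU, and interpreter compilers turned those entries into pipeline passes via ReducePrecisionInsertion::AddPasses. The following is a minimal sketch of that now-deleted usage, reconstructed from the code in this diff rather than from any surviving API:

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"  // deleted by this commit
#include "tensorflow/compiler/xla/xla.pb.h"

// Sketch only: request reduce-precision (5 exponent bits, 10 mantissa bits)
// on the outputs of every kAbs instruction, before optimization runs. The
// compilers' calls to ReducePrecisionInsertion::AddPasses picked this up.
void RequestReducePrecisionOnAbs(xla::DebugOptions* debug_options) {
  *debug_options->add_hlo_reduce_precision_options() =
      xla::ReducePrecisionInsertion::make_options_proto(
          xla::HloReducePrecisionOptions::OP_OUTPUTS, /*exponent_bits=*/5,
          /*mantissa_bits=*/10,
          [](xla::HloOpcode opcode) { return opcode == xla::HloOpcode::kAbs; });
}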


@ -900,7 +900,6 @@ cc_library(
[
":parse_flags_from_env",
":xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",


@ -165,15 +165,6 @@ static void AllocateFlags() {
return true;
};
// Custom "sub-parser" lambda for xla_reduce_precision.
auto setter_for_xla_reduce_precision =
[](string reduce_precision_option_value) {
HloReducePrecisionOptions* option_proto =
flag_values->add_hlo_reduce_precision_options();
return parse_xla_reduce_precision_option(option_proto,
reduce_precision_option_value);
};
// Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any
// locking on the fuel global variables. This means that it's
// illegal/undefined behavior to modify this flag value while the compiler is
@ -389,19 +380,6 @@ static void AllocateFlags() {
"Extra options to pass to a backend; "
"comma-separated list of 'key=val' strings (=val "
"may be omitted); no whitespace around commas."),
tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision,
"",
"Directions for adding reduce-precision operations. "
"Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is "
"the class of locations in which to insert the "
"operations (e.g., 'OP_OUTPUTS'), E and M are the "
"exponent and mantissa bit counts respectively, and "
"OPS and NAMES are comma-separated (no spaces) lists "
"of the operation types and names to which to attach "
"the reduce-precision operations. The NAMES string "
"and its preceding ';' may be omitted. This option "
"may be repeated to define multiple sets of added "
"reduce-precision operations."),
tensorflow::Flag(
"xla_gpu_use_cudnn_batchnorm",
bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),


@ -16,28 +16,29 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
#include <string>
#include <vector>
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"
namespace xla {
template <typename T>
void parse_xla_backend_extra_options(T* extra_options_map,
string comma_separated_values) {
std::vector<string> extra_options_parts =
std::string comma_separated_values) {
std::vector<std::string> extra_options_parts =
absl::StrSplit(comma_separated_values, ',');
// The flag contains a comma-separated list of options; some options
// have arguments following "=", some don't.
for (const auto& part : extra_options_parts) {
size_t eq_pos = part.find_first_of('=');
if (eq_pos == string::npos) {
if (eq_pos == std::string::npos) {
(*extra_options_map)[part] = "";
} else {
string value = "";
std::string value = "";
if (eq_pos + 1 < part.size()) {
value = part.substr(eq_pos + 1);
}
@ -46,98 +47,6 @@ void parse_xla_backend_extra_options(T* extra_options_map,
}
}
// The --xla_reduce_precision option has the format "LOCATION=E,M:OPS;NAMES",
// where LOCATION is an HloReducePrecisionOptions::location, E and M are
// integers for the exponent and mantissa bit counts respectively, and OPS and
// NAMES are comma-separated lists of the operation types and names to which to
// attach the reduce-precision operations. The OPS values are matched against
// the strings produced by HloOpcodeString, while the NAME values are arbitrary
// strings subject to the requirement that they not contain any of "=,:;".
// The NAMES string (with its preceding semicolon) is optional.
inline bool parse_xla_reduce_precision_option(
HloReducePrecisionOptions* options, string option_string) {
// Split off "LOCATION" from remainder of string.
std::vector<string> eq_split = absl::StrSplit(option_string, '=');
if (eq_split.size() != 2) {
return false;
}
string& location = eq_split[0];
if (location == "OP_INPUTS") {
options->set_location(HloReducePrecisionOptions::OP_INPUTS);
} else if (location == "OP_OUTPUTS") {
options->set_location(HloReducePrecisionOptions::OP_OUTPUTS);
} else if (location == "UNFUSED_OP_OUTPUTS") {
options->set_location(HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
} else if (location == "FUSION_INPUTS_BY_CONTENT") {
options->set_location(HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT);
} else if (location == "FUSION_OUTPUTS_BY_CONTENT") {
options->set_location(HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT);
} else {
return false;
}
// Split off "E,M" from remainder of string.
std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':');
if (colon_split.size() != 2) {
return false;
}
// Split E and M, and parse.
std::vector<int32> bitsizes;
for (const auto& s : absl::StrSplit(colon_split[0], ',')) {
bitsizes.emplace_back();
if (!absl::SimpleAtoi(s, &bitsizes.back())) {
return false;
}
}
options->set_exponent_bits(bitsizes[0]);
options->set_mantissa_bits(bitsizes[1]);
// Split off OPS comma-separated list from remainder of string, if the
// remainder exists.
std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';');
if (semicolon_split.size() > 2) {
return false;
}
// The opcode values are either 'all' (meaning all opcodes), or matches to
// the strings returned by HloOpcodeString. An empty string is also
// interpreted as 'all', for convenience. Note that 'all' may not be part
// of a comma-separated list; it must stand alone.
string& opcode_string = semicolon_split[0];
if (opcode_string == "" || opcode_string == "all") {
for (int i = 0; i < HloOpcodeCount(); i++) {
options->add_opcodes_to_suffix(i);
}
} else {
std::vector<string> opcodes = absl::StrSplit(opcode_string, ',');
for (const string& opcode : opcodes) {
bool found = false;
for (int i = 0; i < HloOpcodeCount(); i++) {
if (opcode == HloOpcodeString(static_cast<HloOpcode>(i))) {
options->add_opcodes_to_suffix(i);
found = true;
break;
}
}
if (!found) {
return false;
}
}
}
// Process the NAMES string, if it exists.
if (semicolon_split.size() == 2) {
std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ',');
for (const string& opname : opnames) {
if (opname.length() > 0) {
options->add_opname_substrings_to_suffix(opname);
}
}
}
return true;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
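The surviving parse_xla_backend_extra_options helper above splits each comma-separated entry on its first '=' only, so option values may themselves contain '='. A minimal usage sketch (the map type is whatever the caller instantiates the template with; a plain std::map is assumed here):

#include <map>
#include <string>

#include "tensorflow/compiler/xla/debug_options_parsers.h"

void ParseBackendExtraOptionsExample() {
  std::map<std::string, std::string> extra_options;
  // Keys without '=' get an empty value; everything after the first '=' is
  // kept verbatim, so "ee=ff=gg" yields {"ee", "ff=gg"}.
  xla::parse_xla_backend_extra_options(&extra_options, "aa=bb,cc,ee=ff=gg");
  // extra_options == {{"aa", "bb"}, {"cc", ""}, {"ee", "ff=gg"}}
}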


@ -36,65 +36,6 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) {
EXPECT_EQ(test_map.at("ee"), "ff=gg");
}
// Test that the xla_reduce_precision flag is parsed correctly.
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) {
HloReducePrecisionOptions proto;
string test_string = "OP_OUTPUTS=5,10:add,dot";
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
EXPECT_EQ(proto.opcodes_to_suffix_size(), 2);
EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
HloOpcode::kAdd);
EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(1)),
HloOpcode::kDot);
EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0);
}
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) {
HloReducePrecisionOptions proto;
string test_string = "OP_OUTPUTS=5,10:add,dot;";
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
EXPECT_EQ(proto.opcodes_to_suffix_size(), 2);
EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
HloOpcode::kAdd);
EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(1)),
HloOpcode::kDot);
EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0);
}
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) {
HloReducePrecisionOptions proto;
string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz";
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
EXPECT_EQ(proto.opcodes_to_suffix_size(), HloOpcodeCount());
EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2);
EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo");
EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz");
}
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) {
HloReducePrecisionOptions proto;
string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz";
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
EXPECT_EQ(proto.opcodes_to_suffix_size(), 1);
EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
HloOpcode::kSubtract);
EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2);
EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo");
EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz");
}
} // namespace xla
int main(int argc, char* argv[]) {


@ -3763,36 +3763,6 @@ tf_cc_test(
],
)
cc_library(
name = "reduce_precision_insertion",
srcs = ["reduce_precision_insertion.cc"],
hdrs = ["reduce_precision_insertion.h"],
deps = [
":hlo",
":hlo_pass",
":hlo_pass_pipeline",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "reduce_precision_insertion_test",
size = "small",
srcs = ["reduce_precision_insertion_test.cc"],
deps = [
":hlo",
":hlo_matchers",
":reduce_precision_insertion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "hlo_runner",
srcs = ["hlo_runner.cc"],


@ -135,7 +135,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:rng_expander",
"//tensorflow/compiler/xla/service:sort_simplifier",


@ -93,7 +93,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
@ -251,10 +250,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
pipeline.AddPass<ConditionalToSelect>();
pipeline.AddPass<MapInliner>();
@ -337,9 +332,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<CpuInstructionFusion>();
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
return pipeline.Run(module).status();
}


@ -1130,7 +1130,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:rng_expander",
"//tensorflow/compiler/xla/service:slice_sinker",


@ -78,7 +78,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
@ -136,9 +135,6 @@ Status GpuCompiler::OptimizeHloModule(
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
@ -263,24 +259,6 @@ Status GpuCompiler::OptimizeHloModule(
/*only_fusion_computations=*/true);
fusion.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
/* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
* fixing the ticket. */
reduce_pipeline.AddInvariantChecker<HloVerifier>(
/*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
TF_RETURN_IF_ERROR(reduce_result.status());
if (reduce_result.ValueOrDie()) {
// Do another fusion pass, with the expectation that we may be able to
// fuse the new ReducePrecision operations.
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
}
}
return Status::OK();


@ -51,7 +51,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:triangular_solve_expander",
"//tensorflow/compiler/xla/service:while_loop_simplifier",


@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
@ -87,10 +86,6 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
hlo_module->mutable_entry_computation_layout(),
LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
return pipeline.Run(hlo_module).status();
}


@ -1,310 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify(
const HloComputation* computation) {
std::vector<HloInstruction*> instruction_list;
switch (location_) {
case HloReducePrecisionOptions::OP_INPUTS:
case HloReducePrecisionOptions::OP_OUTPUTS:
case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
for (auto* instruction : computation->instructions()) {
VLOG(4) << "Visited instruction: " << instruction->ToString();
if (instruction_filter_function_(instruction)) {
instruction_list.push_back(instruction);
}
}
break;
case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
for (auto* instruction : computation->instructions()) {
VLOG(4) << "Visited instruction: " << instruction->ToString();
if (instruction->opcode() != HloOpcode::kFusion) {
continue;
}
for (auto* fused_instruction :
instruction->fused_instructions_computation()->instructions()) {
VLOG(4) << "Checking sub-instruction: "
<< fused_instruction->ToString();
if (instruction_filter_function_(fused_instruction)) {
instruction_list.push_back(instruction);
break;
}
}
}
break;
default:
break;
}
VLOG(1) << "Found " << instruction_list.size()
<< " candidate instruction(s) for reduce-precision insertion";
return instruction_list;
}
StatusOr<bool> ReducePrecisionInsertion::insert_after(
HloInstruction* instruction) {
// Check that this isn't already an equivalent operation.
if (is_redundant(instruction)) {
VLOG(2) << "Skipped: instruction is already an equivalent"
" reduce-precision instruction:"
<< instruction->ToString();
return false;
}
// Check that we haven't already inserted an equivalent reduce-precision
// operation after this instruction. (The zero-user case occurs when this is
// the root instruction.)
if (instruction->user_count() > 0) {
bool redundant_followers = true;
for (HloInstruction* user : instruction->users()) {
if (!is_redundant(user)) {
redundant_followers = false;
break;
}
}
if (redundant_followers) {
VLOG(2) << "Skipped: instruction already followed by equivalent"
" reduce-precision instructions";
return false;
}
}
HloInstruction* reduced = instruction->parent()->AddInstruction(
HloInstruction::CreateReducePrecision(instruction->shape(), instruction,
exponent_bits_, mantissa_bits_));
TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced));
return true;
}
StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
const std::vector<HloInstruction*>& instructions) {
bool computation_changed = false;
for (auto instruction : instructions) {
VLOG(2) << "Adding reduce-precision operation to inputs of instruction: "
<< instruction->ToString();
for (int64 i = 0; i < instruction->operand_count(); i++) {
HloInstruction* operand = instruction->mutable_operand(i);
VLOG(2) << "Adding to operand " << i << ": " << operand;
if (!is_valid_shape(operand->shape())) {
VLOG(2) << "Skipped: value is not of type F32";
continue;
}
if (is_redundant(operand)) {
VLOG(2) << "Skipped: operand is already an equivalent reduce-precision"
" instruction";
continue;
}
if (instruction->IsInputFusion() || instruction->IsLoopFusion()) {
// Insert the reduce-precision operation inside the fusion computation,
// after the corresponding parameter instruction.
TF_ASSIGN_OR_RETURN(
bool instruction_changed,
insert_after(instruction->fused_instructions_computation()
->parameter_instruction(i)));
computation_changed |= instruction_changed;
} else {
// Look for an existing reduce-precision operation on the operand. (We
// need to be careful not to create a loop, though!)
HloInstruction* reduced = nullptr;
for (auto& user : operand->users()) {
if (user != instruction &&
user->opcode() == HloOpcode::kReducePrecision &&
user->exponent_bits() == exponent_bits_ &&
user->mantissa_bits() == mantissa_bits_) {
reduced = user;
break;
}
}
// If there wasn't an existing reduce-precision operation, create one.
if (!reduced) {
reduced = instruction->parent()->AddInstruction(
HloInstruction::CreateReducePrecision(
operand->shape(), operand, exponent_bits_, mantissa_bits_));
}
// Insert the reduce-precision operation before the operand.
TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced));
computation_changed = true;
}
}
}
return computation_changed;
}
StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
const std::vector<HloInstruction*>& instructions) {
bool computation_changed = false;
for (const auto& instruction : instructions) {
VLOG(2) << "Adding reduce-precision operation to output of instruction: "
<< instruction->ToString();
if (!is_valid_shape(instruction->shape())) {
VLOG(2) << "Skipped: value is not of type F32";
continue;
}
if (instruction->IsLoopFusion() || instruction->IsOutputFusion()) {
// Insert the reduce-precision operation as the last operation inside
// the fusion computation.
HloInstruction* fusion_root = instruction->fused_expression_root();
VLOG(2) << "Inserting new operation after existing fusion root: "
<< fusion_root->ToString();
TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root));
computation_changed |= instruction_changed;
} else {
// Insert the reduce-precision operation after the instruction.
TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction));
computation_changed |= instruction_changed;
}
}
return computation_changed;
}
StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
bool changed = false;
VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
for (auto* computation : module->MakeNonfusionComputations()) {
StatusOr<bool> computation_changed;
switch (location_) {
case HloReducePrecisionOptions::OP_INPUTS:
case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
computation_changed = ReducePrecisionInsertion::insert_on_inputs(
instructions_to_modify(computation));
break;
case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
case HloReducePrecisionOptions::OP_OUTPUTS:
case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
computation_changed = ReducePrecisionInsertion::insert_on_outputs(
instructions_to_modify(computation));
break;
default:
break;
}
TF_RETURN_IF_ERROR(computation_changed.status());
if (computation_changed.ValueOrDie()) {
changed = true;
VLOG(3) << "Computation after reduce-precision insertion:";
XLA_VLOG_LINES(3, computation->ToString());
} else {
VLOG(3) << "Computation " << computation->name() << " unchanged";
}
}
return changed;
}
ReducePrecisionInsertion::InstructionFilterFunction
ReducePrecisionInsertion::make_filter_function(
const HloReducePrecisionOptions& reduce_precision_options) {
// Implement the filter function with a lookup table.
std::vector<bool> opcode_filter(HloOpcodeCount(), false);
for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
opcode_filter[opcode] = true;
}
if (reduce_precision_options.opname_substrings_to_suffix_size() == 0) {
return [opcode_filter](const HloInstruction* instruction) {
return opcode_filter[static_cast<unsigned int>(instruction->opcode())];
};
} else {
std::vector<string> opname_substrings;
for (const auto& substring :
reduce_precision_options.opname_substrings_to_suffix()) {
opname_substrings.push_back(substring);
}
return [opcode_filter,
opname_substrings](const HloInstruction* instruction) {
if (!opcode_filter[static_cast<unsigned int>(instruction->opcode())]) {
return false;
}
const auto& opname = instruction->metadata().op_name();
for (const auto& substring : opname_substrings) {
if (opname.find(substring) != string::npos) {
return true;
}
}
return false;
};
}
}
HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
const HloReducePrecisionOptions::Location location, const int exponent_bits,
const int mantissa_bits,
const std::function<bool(HloOpcode)>& opcode_filter_function,
const std::vector<string>& opname_substring_list) {
HloReducePrecisionOptions options;
options.set_location(location);
options.set_exponent_bits(exponent_bits);
options.set_mantissa_bits(mantissa_bits);
for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) {
if (opcode_filter_function(static_cast<HloOpcode>(opcode))) {
options.add_opcodes_to_suffix(opcode);
}
}
for (auto& string : opname_substring_list) {
options.add_opname_substrings_to_suffix(string);
}
return options;
}
bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline,
const DebugOptions& debug_options,
const PassTiming pass_timing) {
bool passes_added = false;
for (const auto& pass_options :
debug_options.hlo_reduce_precision_options()) {
bool add_pass;
switch (pass_options.location()) {
case HloReducePrecisionOptions::OP_INPUTS:
case HloReducePrecisionOptions::OP_OUTPUTS:
add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION;
break;
case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
add_pass = pass_timing == PassTiming::AFTER_FUSION;
break;
default:
add_pass = false;
}
if (add_pass) {
pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
passes_added = true;
}
}
return passes_added;
}
} // namespace xla
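For readers skimming the deleted implementation above: the pass was constructed with explicit bit counts, a location, and an instruction filter, and then run like any other HLO pass. A minimal sketch of that usage, mirroring the InsertOps helper in the removed unit test rather than any surviving API:

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"  // deleted by this commit
#include "tensorflow/compiler/xla/statusor.h"

// Sketch only: insert reduce-precision(5, 10) after the output of every kCos
// instruction in `module`. Returns true if the module was changed.
xla::StatusOr<bool> AddReducePrecisionAfterCos(xla::HloModule* module) {
  xla::ReducePrecisionInsertion pass(
      /*exponent_bits=*/5, /*mantissa_bits=*/10,
      xla::HloReducePrecisionOptions::OP_OUTPUTS,
      [](const xla::HloInstruction* instruction) {
        return instruction->opcode() == xla::HloOpcode::kCos;
      });
  return pass.Run(module);
}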


@ -1,146 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
namespace xla {
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:
// The exponent_bits and mantissa_bits arguments specify the parameters of
// the instructions to insert. The instructions will be inserted after each
// instruction with an opcode for which the instruction_filter_function
// function returns true and the output type is F32.
explicit ReducePrecisionInsertion(
const int exponent_bits, const int mantissa_bits,
const HloReducePrecisionOptions::Location location,
const InstructionFilterFunction& instruction_filter_function)
: exponent_bits_(exponent_bits),
mantissa_bits_(mantissa_bits),
location_(location),
instruction_filter_function_(instruction_filter_function) {}
// Version of the constructor that takes an HloReducePrecisionOptions proto
// rather than explicitly-enumerated parameters, for convenience when
// creating passes based on DebugOptions.
explicit ReducePrecisionInsertion(
const HloReducePrecisionOptions& reduce_precision_options)
: exponent_bits_(reduce_precision_options.exponent_bits()),
mantissa_bits_(reduce_precision_options.mantissa_bits()),
location_(reduce_precision_options.location()),
instruction_filter_function_(
make_filter_function(reduce_precision_options)) {}
~ReducePrecisionInsertion() override{};
absl::string_view name() const override {
return "reduce-precision-insertion";
}
// Run the pass on the given module. Returns whether the module was changed
// (reduce-precision instructions were inserted).
StatusOr<bool> Run(HloModule* module) override;
// Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
// representation and InstructionFilterFunction functions.
static InstructionFilterFunction make_filter_function(
const HloReducePrecisionOptions& reduce_precision_options);
static HloReducePrecisionOptions make_options_proto(
const HloReducePrecisionOptions::Location location,
const int exponent_bits, const int mantissa_bits,
const std::function<bool(HloOpcode)>& opcode_filter_function,
const std::vector<string>& opname_substring_list = {});
// Enumeration to control which passes should be added.
enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };
// Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list
// of HloReducePrecisionOptions in a DebugOptions proto. Returns true if any
// passes were added.
static bool AddPasses(HloPassPipeline* pipeline,
const DebugOptions& debug_options,
const PassTiming pass_timing);
private:
// Select the instructions that should have reduce-precision operations
// attached to them.
std::vector<HloInstruction*> instructions_to_modify(
const HloComputation* computation);
// Insert a reduce-precision operation into the graph on the output of the
// given instruction.
StatusOr<bool> insert_after(HloInstruction* instruction);
// Insert reduce-precision operations into the graph on the inputs of the
// given instructions. (For fusion instructions, the operations will be
// inserted inside the fusion computation, on the outputs of the relevant
// input parameters.)
StatusOr<bool> insert_on_inputs(
const std::vector<HloInstruction*>& instructions);
// Insert reduce-precision operations into the graph on the outputs of the
// given instructions. (For fusion instructions, the operations will be
// inserted inside the fusion computation as a new root.)
StatusOr<bool> insert_on_outputs(
const std::vector<HloInstruction*>& instructions);
// Is this shape valid for inserting a reduce-precision operation?
bool is_valid_shape(const Shape& shape) {
// For now, ReducePrecision is only implemented for F32 arrays, so this
// ignores instructions that produce other data. In particular, this
// currently ignores instructions producing tuples, even if those tuples
// contain F32 arrays inside them. The assumption is that in most cases
// equivalent behavior can be obtained by adding ReducePrecision
// instructions after the instructions that pull the F32 arrays out of
// the tuples.
return shape.element_type() == PrimitiveType::F32;
}
// Is this instruction one such that following or preceding it with a new
// reduce-precision operation will be redundant?
bool is_redundant(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kReducePrecision &&
instruction->exponent_bits() <= exponent_bits_ &&
instruction->mantissa_bits() <= mantissa_bits_;
}
// Parameters for the precision reduction to be added.
const int exponent_bits_;
const int mantissa_bits_;
// Pass "timing" parameter. This also controls aspects of how the pass
// selects locations to insert instructions.
const HloReducePrecisionOptions::Location location_;
// User-provided Function to determine whether a given instruction should
// have a reduce-precision instruction inserted in its output stream.
const InstructionFilterFunction instruction_filter_function_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_


@ -1,578 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
using ::testing::UnorderedElementsAre;
class ReducePrecisionInsertionTest : public HloTestBase {
protected:
bool InsertOps(HloModule* module,
const HloReducePrecisionOptions::Location location,
const std::function<bool(const HloInstruction*)>& filter) {
ReducePrecisionInsertion op_insertion(5, 10, location, filter);
StatusOr<bool> result = op_insertion.Run(module);
EXPECT_IS_OK(result.status());
return result.ValueOrDie();
}
};
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a simple graph with a parameter feeding a unary cosine function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_EQ(b->operand(0), a);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm expected graph after adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
}
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {});
// Create a simple graph with a parameter feeding a unary cosine function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_EQ(b->operand(0), a);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm expected graph after adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
}
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a simple graph with parameter feeding a binary add function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), c);
EXPECT_EQ(c->operand(0), a);
EXPECT_EQ(c->operand(1), b);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kAdd;
}));
// Confirm expected graph after adding ops.
EXPECT_EQ(computation->root_instruction(), c);
EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
EXPECT_THAT(c->operand(1), op::ReducePrecision(b));
}
TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a simple graph with a parameter feeding a unary cosine function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_EQ(b->operand(0), a);
EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() ==
HloOpcode::kParameter;
}));
// Confirm that graph has not changed.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_EQ(b->operand(0), a);
}
TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a simple graph with parameter feeding a binary add function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));
HloInstruction* d = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), d);
EXPECT_EQ(b->operand(0), a);
EXPECT_EQ(c->operand(0), a);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos ||
instruction->opcode() == HloOpcode::kSin;
}));
// Confirm expected graph after adding ops. In particular, we want to confirm
// that the reduced-precision operation added for the input to b is re-used
// for the input to c.
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
EXPECT_EQ(b->operand(0), c->operand(0));
}
TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a simple graph with a parameter feeding a unary cosine function.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm expected graph after adding ops.
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision(b));
}
TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
// Create a graph with two parameters feeding into unary cosine functions,
// and the output of those feeds into an add function. Feeding the outputs
// from the suffixed cosine functions into a binary add function allows us to
// confirm that the separate operand streams are not crossed when the new
// instructions are inserted.
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* a_cos = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
HloInstruction* b =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* b_cos = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
// Confirm expected graph before adding ops.
EXPECT_EQ(c->operand(0), a_cos);
EXPECT_EQ(c->operand(1), b_cos);
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm expected graph after adding ops.
EXPECT_THAT(c->operand(0), op::ReducePrecision());
EXPECT_EQ(c->operand(0)->operand(0), a_cos);
EXPECT_THAT(c->operand(1), op::ReducePrecision());
EXPECT_EQ(c->operand(1)->operand(0), b_cos);
}
TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(S32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected graph before adding ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
// Since none of the instructions produce F32 data, this should not change
// the graph.
EXPECT_FALSE(
InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) { return true; }));
// Confirm that graph has not changed.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
}
TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected graph before adding ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
// Since none of the instructions match the should_reduce_output_precision
// function, this should not change the graph.
EXPECT_FALSE(
InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) { return false; }));
// Confirm that graph has not changed.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
}
TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateReducePrecision(shape, a, 8, 23));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected state before adding ops.
EXPECT_EQ(computation->root_instruction(), b);
// This should insert a new ReducePrecision after the existing one, but
// should not then recurse by adding another after the just-inserted one.
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() ==
HloOpcode::kReducePrecision;
}));
// Confirm expected graph after adding ops.
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
EXPECT_EQ(computation->root_instruction()->operand(0), b);
}
TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateReducePrecision(shape, x, 5, 10));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected graph before adding ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
// Since the new reduce-precision operation would be redundant, this
// should not change the graph.
EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() ==
HloOpcode::kParameter;
}));
// Confirm that graph has not changed.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
}
TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateReducePrecision(shape, x, 8, 23));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Confirm expected graph before adding ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
EXPECT_EQ(computation->root_instruction(), y);
// Since the new reduce-precision operation is not the same as the existing
// one, this should add a new one.
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kParameter;
}));
// Confirm that graph is as expected.
EXPECT_EQ(computation->root_instruction(), y);
EXPECT_THAT(y->operand(0), op::ReducePrecision(x));
}
TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Manually fuse the kCos operation into a fusion operation.
HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
shape, HloInstruction::FusionKind::kLoop, y));
EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
EXPECT_IS_OK(computation->RemoveInstruction(y));
// Confirm expected graph before adding reduce-precision ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
HloInstruction* y_fused = z->fused_expression_root();
EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
// The ReducePrecisionInsertion pass should not see inside the fusion
// operation, so this should not change the graph.
EXPECT_FALSE(InsertOps(module.get(),
HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm that graph has not changed.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
EXPECT_EQ(z->fused_expression_root(), y_fused);
}
TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Manually fuse the kCos operation into a fusion operation.
HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
shape, HloInstruction::FusionKind::kLoop, y));
EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
EXPECT_IS_OK(computation->RemoveInstruction(y));
// Confirm expected graph before adding reduce-precision ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
HloInstruction* y_fused = z->fused_expression_root();
EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
// This should see that the fusion computation contains a kCos operation,
// and insert a new reduce-precision node at its input.
EXPECT_TRUE(InsertOps(module.get(),
HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// This should refuse to insert a second reduce-precision operation, as
// it would be redundant with the first.
EXPECT_FALSE(InsertOps(module.get(),
HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm that the top-level computation still only contains the fusion
// instruction, but that the fused computation now has a reduce-precision
// instruction inserted after its parameter instruction.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
EXPECT_THAT(z->fused_expression_root(), y_fused);
EXPECT_THAT(y_fused->operand(0), op::ReducePrecision(op::Parameter()));
}
TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Manually fuse the kCos operation into a fusion operation.
HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
shape, HloInstruction::FusionKind::kLoop, y));
EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
EXPECT_IS_OK(computation->RemoveInstruction(y));
// Confirm expected graph before adding reduce-precision ops.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
HloInstruction* y_fused = z->fused_expression_root();
EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
// This should see that the fusion computation contains a kCos operation,
// and insert a new reduce-precision node at its root.
EXPECT_TRUE(InsertOps(module.get(),
HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// This should refuse to insert a second reduce-precision operation, as
// it would be redundant with the first.
EXPECT_FALSE(InsertOps(module.get(),
HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
[](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCos;
}));
// Confirm that the top-level computation still only contains the fusion
// instruction, but that the fused computation now has a reduce-precision
// instruction inserted as its root.
EXPECT_THAT(x->users(), UnorderedElementsAre(z));
EXPECT_EQ(computation->root_instruction(), z);
EXPECT_THAT(z->fused_expression_root(), op::ReducePrecision(y_fused));
}
TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));
auto options_proto = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });
auto filter_function =
ReducePrecisionInsertion::make_filter_function(options_proto);
EXPECT_TRUE(filter_function(b));
EXPECT_FALSE(filter_function(c));
}
TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionWithSubstrings) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
OpMetadata b_metadata;
b_metadata.set_op_name("FlowTensor/foom");
b->set_metadata(b_metadata);
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
OpMetadata c_metadata;
c_metadata.set_op_name("FlowTensor/barn");
c->set_metadata(c_metadata);
auto options_proto = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kCos; },
{"foo", "baz"});
auto filter_function =
ReducePrecisionInsertion::make_filter_function(options_proto);
EXPECT_TRUE(filter_function(b));
EXPECT_FALSE(filter_function(c));
}
} // namespace xla


@ -910,7 +910,6 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",


@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -532,138 +531,5 @@ void ReducedPrecisionAccuracyTest::DoIt(
INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));
// Tests to confirm that the compiler optimization functions add the expected
// ReducePrecisionInsertion passes.
class ReducePrecisionInsertionTest : public ClientLibraryTestBase {};
// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// Abs doesn't affect resolution.
auto abs = Abs(a);
// Near 1.0, Log(x) approximates x - 1; this lets us confirm that the
// reduce-precision operation showed up in the correct place in the
// graph.
Log(abs);
// Insert precision-reduction after the Abs(x) operation, rounding that
// result to exactly 1.0f.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
->add_hlo_reduce_precision_options();
*reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });
ComputeAndCompareR1<float>(&builder, {0.0f}, {a_data.get()});
}
// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
Neg(abs);
// Add a pass after operation fusion, suffixing kAbs operations. This
// should not see into the fusion nodes and thus should not affect the
// result.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
->add_hlo_reduce_precision_options();
*reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });
ComputeAndCompareR1<float>(&builder, {-1.00001f, -1.00001f}, {a_data.get()});
}
// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
Neg(abs);
// Add a pass after operation fusion, suffixing kFusion operations.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
->add_hlo_reduce_precision_options();
*reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; });
ComputeAndCompareR1<float>(&builder, {-1.0f, -1.0f}, {a_data.get()});
}
// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
Neg(abs);
// Add a pass suffixing fusion nodes containing kCos operations. This
// should have no effect.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
->add_hlo_reduce_precision_options();
*reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });
ComputeAndCompareR1<float>(&builder, {-1.00001f}, {a_data.get()});
}
// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
Neg(abs);
// Add a pass suffixing fusion nodes containing kAbs operations. This
// should see the kAbs operation within the above fusion node.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
->add_hlo_reduce_precision_options();
*reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10,
[](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });
ComputeAndCompareR1<float>(&builder, {-1.0f, -1.0f}, {a_data.get()});
}
} // namespace
} // namespace xla
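The accuracy tests above rest on one piece of arithmetic: with 10 mantissa bits the spacing of representable values around 1.0 is 2^-10 (about 0.000977), so 1.00001f rounds back to exactly 1.0f, which is why Log(Abs(a)) compares equal to 0.0f after the pass runs. Below is a standalone sketch of that mantissa rounding; it ignores the 5-bit exponent clamping, NaNs, and denormals that the real kReducePrecision instruction also handles:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an F32 value to 10 mantissa bits with round-to-nearest-even, roughly
// what a kReducePrecision(5, 10) instruction does to values in the normal
// range (exponent handling is omitted in this sketch).
float ReduceMantissaTo10Bits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  const int dropped = 23 - 10;  // F32 carries 23 mantissa bits.
  const uint32_t half = 1u << (dropped - 1);
  const uint32_t mask = (1u << dropped) - 1;
  // Round to nearest, ties to even, then clear the dropped bits.
  bits += (half - 1) + ((bits >> dropped) & 1);
  bits &= ~mask;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

int main() {
  // 1.00001f differs from 1.0f only below the 10-bit granularity, so it rounds
  // back to exactly 1.0f -- hence the expected results in the tests above.
  std::printf("%.7f\n", ReduceMantissaTo10Bits(1.00001f));  // prints 1.0000000
  return 0;
}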


@ -20,44 +20,6 @@ package xla;
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
// Options for the HLO insert-reduce-precision-operations pass.
message HloReducePrecisionOptions {
// Where and when the reduce-precision operations will be added.
enum Location {
// Add reduce-precision operations to the inputs of selected instructions.
// This is done before any optimization occurs.
OP_INPUTS = 0;
// Add reduce-precision operations to the outputs of selected instructions.
// This is done before any optimization occurs.
OP_OUTPUTS = 1;
// After operation-fusion occurs, add reduce-precision operations to the
// outputs of any selected instructions that have not been fused into
// fusion instructions.
UNFUSED_OP_OUTPUTS = 2;
// After operation-fusion occurs, add reduce-precision operations to the
// outputs of any fusion instructions that contain operations matching the
// selection criteria.
FUSION_INPUTS_BY_CONTENT = 3;
// After operation-fusion occurs, add reduce-precision operations to the
// outputs of any fusion instructions that contain operations matching the
// selection criteria.
FUSION_OUTPUTS_BY_CONTENT = 4;
}
Location location = 1;
// Exponent and mantissa bit counts for the reduced precision.
uint32 exponent_bits = 2;
uint32 mantissa_bits = 3;
// Operations matching these opcodes should be suffixed with reduce-precision
// operations.
repeated uint32 opcodes_to_suffix = 4;
// Operations with names containing these substrings should be suffixed with
// reduce-precision operations.
repeated string opname_substrings_to_suffix = 5;
}
// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
@ -122,10 +84,7 @@ message DebugOptions {
// If true, a set of expensive LLVM optimization passes will not be run.
bool xla_llvm_disable_expensive_passes = 73;
// Options for inserting reduce-precision operations for numerical
// experimentation. This is a repeated field, as we may want to have
// multiple passes with different parameters.
repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;
reserved 80; // Was hlo_reduce_precision_options
// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
computation will run n! times with all permutations of layouts for the