Remove ReducePrecisionInsertion pass.
It no longer appears to be useful. Also remove the related XLA flags that could be used to enable it.

PiperOrigin-RevId: 278331729
Change-Id: I3c0094b60f4b51ee3b64ec35bca96ec17f69c5f6
parent 32d76ec3e6
commit f211533f4d
@@ -900,7 +900,6 @@ cc_library(
    [
        ":parse_flags_from_env",
        ":xla_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -165,15 +165,6 @@ static void AllocateFlags() {
    return true;
  };

  // Custom "sub-parser" lambda for xla_reduce_precision.
  auto setter_for_xla_reduce_precision =
      [](string reduce_precision_option_value) {
        HloReducePrecisionOptions* option_proto =
            flag_values->add_hlo_reduce_precision_options();
        return parse_xla_reduce_precision_option(option_proto,
                                                 reduce_precision_option_value);
      };

  // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any
  // locking on the fuel global variables. This means that it's
  // illegal/undefined behavior to modify this flag value while the compiler is
@@ -389,19 +380,6 @@ static void AllocateFlags() {
          "Extra options to pass to a backend; "
          "comma-separated list of 'key=val' strings (=val "
          "may be omitted); no whitespace around commas."),
      tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision,
                       "",
                       "Directions for adding reduce-precision operations. "
                       "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is "
                       "the class of locations in which to insert the "
                       "operations (e.g., 'OP_OUTPUTS'), E and M are the "
                       "exponent and mantissa bit counts respectively, and "
                       "OPS and NAMES are comma-separated (no spaces) lists "
                       "of the operation types and names to which to attach "
                       "the reduce-precision operations. The NAMES string "
                       "and its preceding ';' may be omitted. This option "
                       "may be repeated to define multiple sets of added "
                       "reduce-precision operations."),
      tensorflow::Flag(
          "xla_gpu_use_cudnn_batchnorm",
          bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),
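For reference, a concrete value accepted by the flag removed above, taken from the deleted tests further below (5 and 10 are the exponent and mantissa bit counts, and add/dot name the affected opcodes):

    --xla_reduce_precision=OP_OUTPUTS=5,10:add,dot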
@@ -16,28 +16,29 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_

#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"

namespace xla {

template <typename T>
void parse_xla_backend_extra_options(T* extra_options_map,
-                                    string comma_separated_values) {
-  std::vector<string> extra_options_parts =
+                                    std::string comma_separated_values) {
+  std::vector<std::string> extra_options_parts =
      absl::StrSplit(comma_separated_values, ',');

  // The flag contains a comma-separated list of options; some options
  // have arguments following "=", some don't.
  for (const auto& part : extra_options_parts) {
    size_t eq_pos = part.find_first_of('=');
-    if (eq_pos == string::npos) {
+    if (eq_pos == std::string::npos) {
      (*extra_options_map)[part] = "";
    } else {
-      string value = "";
+      std::string value = "";
      if (eq_pos + 1 < part.size()) {
        value = part.substr(eq_pos + 1);
      }
@@ -46,98 +47,6 @@ void parse_xla_backend_extra_options(T* extra_options_map,
  }
}

// The --xla_reduce_precision option has the format "LOCATION=E,M:OPS;NAME",
// where LOCATION is an HloReducePrecisionOptions::location, E and M are
// integers for the exponent and mantissa bit counts respectively, and OPS and
// NAMES are comma-separated lists of the operation types and names to which
// to attach the reduce-precision operations. The OPS values are matched to
// the strings produced by HloOpcodeString, while the NAME values are
// arbitrary strings subject to the requirement that they not contain any of
// "=,:;". The NAME string (with its preceding semicolon) is optional.
inline bool parse_xla_reduce_precision_option(
    HloReducePrecisionOptions* options, string option_string) {
  // Split off "LOCATION" from the remainder of the string.
  std::vector<string> eq_split = absl::StrSplit(option_string, '=');
  if (eq_split.size() != 2) {
    return false;
  }
  string& location = eq_split[0];
  if (location == "OP_INPUTS") {
    options->set_location(HloReducePrecisionOptions::OP_INPUTS);
  } else if (location == "OP_OUTPUTS") {
    options->set_location(HloReducePrecisionOptions::OP_OUTPUTS);
  } else if (location == "UNFUSED_OP_OUTPUTS") {
    options->set_location(HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
  } else if (location == "FUSION_INPUTS_BY_CONTENT") {
    options->set_location(HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT);
  } else if (location == "FUSION_OUTPUTS_BY_CONTENT") {
    options->set_location(HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT);
  } else {
    return false;
  }

  // Split off "E,M" from the remainder of the string.
  std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':');
  if (colon_split.size() != 2) {
    return false;
  }

  // Split E and M, and parse.
  std::vector<int32> bitsizes;
  for (const auto& s : absl::StrSplit(colon_split[0], ',')) {
    bitsizes.emplace_back();
    if (!absl::SimpleAtoi(s, &bitsizes.back())) {
      return false;
    }
  }
  options->set_exponent_bits(bitsizes[0]);
  options->set_mantissa_bits(bitsizes[1]);

  // Split off the OPS comma-separated list from the remainder of the string,
  // if the remainder exists.
  std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';');
  if (semicolon_split.size() > 2) {
    return false;
  }
  // The opcode values are either 'all' (meaning all opcodes), or matches to
  // the strings returned by HloOpcodeString. An empty string is also
  // interpreted as 'all', for convenience. Note that 'all' may not be part
  // of a comma-separated list; it must stand alone.
  string& opcode_string = semicolon_split[0];
  if (opcode_string == "" || opcode_string == "all") {
    for (int i = 0; i < HloOpcodeCount(); i++) {
      options->add_opcodes_to_suffix(i);
    }
  } else {
    std::vector<string> opcodes = absl::StrSplit(opcode_string, ',');
    for (const string& opcode : opcodes) {
      bool found = false;
      for (int i = 0; i < HloOpcodeCount(); i++) {
        if (opcode == HloOpcodeString(static_cast<HloOpcode>(i))) {
          options->add_opcodes_to_suffix(i);
          found = true;
          break;
        }
      }
      if (!found) {
        return false;
      }
    }
  }

  // Process the NAMES string, if it exists.
  if (semicolon_split.size() == 2) {
    std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ',');
    for (const string& opname : opnames) {
      if (opname.length() > 0) {
        options->add_opname_substrings_to_suffix(opname);
      }
    }
  }

  return true;
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
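As a worked instance of the 'LOCATION=E,M:OPS;NAMES' grammar handled by the parser deleted above (the breakdown mirrors the deleted tests that follow):

    UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz
    // LOCATION: UNFUSED_OP_OUTPUTS
    // E,M:      5 exponent bits, 10 mantissa bits
    // OPS:      subtract (must match an HloOpcodeString value)
    // NAMES:    "foo" and "bar/baz", matched as op-name substrings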
@@ -36,65 +36,6 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) {
  EXPECT_EQ(test_map.at("ee"), "ff=gg");
}

// Test that the xla_reduce_precision flag is parsed correctly.
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) {
  HloReducePrecisionOptions proto;
  string test_string = "OP_OUTPUTS=5,10:add,dot";
  EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
  EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
  EXPECT_EQ(proto.exponent_bits(), 5);
  EXPECT_EQ(proto.mantissa_bits(), 10);
  EXPECT_EQ(proto.opcodes_to_suffix_size(), 2);
  EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
            HloOpcode::kAdd);
  EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(1)),
            HloOpcode::kDot);
  EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0);
}

TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) {
  HloReducePrecisionOptions proto;
  string test_string = "OP_OUTPUTS=5,10:add,dot;";
  EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
  EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
  EXPECT_EQ(proto.exponent_bits(), 5);
  EXPECT_EQ(proto.mantissa_bits(), 10);
  EXPECT_EQ(proto.opcodes_to_suffix_size(), 2);
  EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
            HloOpcode::kAdd);
  EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(1)),
            HloOpcode::kDot);
  EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0);
}

TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) {
  HloReducePrecisionOptions proto;
  string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz";
  EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
  EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
  EXPECT_EQ(proto.exponent_bits(), 5);
  EXPECT_EQ(proto.mantissa_bits(), 10);
  EXPECT_EQ(proto.opcodes_to_suffix_size(), HloOpcodeCount());
  EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2);
  EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo");
  EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz");
}

TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) {
  HloReducePrecisionOptions proto;
  string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz";
  EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
  EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
  EXPECT_EQ(proto.exponent_bits(), 5);
  EXPECT_EQ(proto.mantissa_bits(), 10);
  EXPECT_EQ(proto.opcodes_to_suffix_size(), 1);
  EXPECT_EQ(static_cast<HloOpcode>(proto.opcodes_to_suffix(0)),
            HloOpcode::kSubtract);
  EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2);
  EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo");
  EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz");
}

}  // namespace xla

int main(int argc, char* argv[]) {
@@ -3763,36 +3763,6 @@ tf_cc_test(
    ],
)

cc_library(
    name = "reduce_precision_insertion",
    srcs = ["reduce_precision_insertion.cc"],
    hdrs = ["reduce_precision_insertion.h"],
    deps = [
        ":hlo",
        ":hlo_pass",
        ":hlo_pass_pipeline",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "reduce_precision_insertion_test",
    size = "small",
    srcs = ["reduce_precision_insertion_test.cc"],
    deps = [
        ":hlo",
        ":hlo_matchers",
        ":reduce_precision_insertion",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "hlo_runner",
    srcs = ["hlo_runner.cc"],
@@ -135,7 +135,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/service:indexed_array_analysis",
        "//tensorflow/compiler/xla/service:llvm_compiler",
        "//tensorflow/compiler/xla/service:reduce_precision_insertion",
        "//tensorflow/compiler/xla/service:reshape_mover",
        "//tensorflow/compiler/xla/service:rng_expander",
        "//tensorflow/compiler/xla/service:sort_simplifier",
@@ -93,7 +93,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"

@@ -251,10 +250,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CpuHloSupportChecker>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

@@ -337,9 +332,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(

  pipeline.AddPass<CpuInstructionFusion>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
  return pipeline.Run(module).status();
}
@@ -1130,7 +1130,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/service:llvm_compiler",
        "//tensorflow/compiler/xla/service:reduce_precision_insertion",
        "//tensorflow/compiler/xla/service:reshape_mover",
        "//tensorflow/compiler/xla/service:rng_expander",
        "//tensorflow/compiler/xla/service:slice_sinker",
@@ -78,7 +78,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"

@@ -136,9 +135,6 @@ Status GpuCompiler::OptimizeHloModule(

  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<GpuHloSupportChecker>();
  ReducePrecisionInsertion::AddPasses(
      &pipeline, hlo_module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  // TODO(b/64094172): make Call work on GPU instead of inlining.
  pipeline.AddPass<CallInliner>();

@@ -263,24 +259,6 @@ Status GpuCompiler::OptimizeHloModule(
                               /*only_fusion_computations=*/true);
    fusion.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());

    HloPassPipeline reduce_pipeline("reduce-precision");
    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout
     * after fixing the ticket. */
    reduce_pipeline.AddInvariantChecker<HloVerifier>(
        /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    ReducePrecisionInsertion::AddPasses(
        &reduce_pipeline, hlo_module->config().debug_options(),
        ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
    StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
    TF_RETURN_IF_ERROR(reduce_result.status());

    if (reduce_result.ValueOrDie()) {
      // Do another fusion pass, with the expectation that we may be able to
      // fuse the new ReducePrecision operations.
      TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
    }
  }

  return Status::OK();
@@ -51,7 +51,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
        "//tensorflow/compiler/xla/service:layout_assignment",
        "//tensorflow/compiler/xla/service:map_inliner",
        "//tensorflow/compiler/xla/service:reduce_precision_insertion",
        "//tensorflow/compiler/xla/service:reshape_mover",
        "//tensorflow/compiler/xla/service:triangular_solve_expander",
        "//tensorflow/compiler/xla/service:while_loop_simplifier",
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"

@@ -87,10 +86,6 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
      hlo_module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout);

  ReducePrecisionInsertion::AddPasses(
      &pipeline, hlo_module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  return pipeline.Run(hlo_module).status();
}
@@ -1,310 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify(
    const HloComputation* computation) {
  std::vector<HloInstruction*> instruction_list;

  switch (location_) {
    case HloReducePrecisionOptions::OP_INPUTS:
    case HloReducePrecisionOptions::OP_OUTPUTS:
    case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
      for (auto* instruction : computation->instructions()) {
        VLOG(4) << "Visited instruction: " << instruction->ToString();
        if (instruction_filter_function_(instruction)) {
          instruction_list.push_back(instruction);
        }
      }
      break;

    case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
    case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
      for (auto* instruction : computation->instructions()) {
        VLOG(4) << "Visited instruction: " << instruction->ToString();
        if (instruction->opcode() != HloOpcode::kFusion) {
          continue;
        }
        for (auto* fused_instruction :
             instruction->fused_instructions_computation()->instructions()) {
          VLOG(4) << "Checking sub-instruction: "
                  << fused_instruction->ToString();
          if (instruction_filter_function_(fused_instruction)) {
            instruction_list.push_back(instruction);
            break;
          }
        }
      }
      break;

    default:
      break;
  }
  VLOG(1) << "Found " << instruction_list.size()
          << " candidate instruction(s) for reduce-precision insertion";

  return instruction_list;
}

StatusOr<bool> ReducePrecisionInsertion::insert_after(
    HloInstruction* instruction) {
  // Check that this isn't already an equivalent operation.
  if (is_redundant(instruction)) {
    VLOG(2) << "Skipped: instruction is already an equivalent"
               " reduce-precision instruction:"
            << instruction->ToString();
    return false;
  }

  // Check that we haven't already inserted an equivalent reduce-precision
  // operation after this instruction. (The zero-user case occurs when this is
  // the root instruction.)
  if (instruction->user_count() > 0) {
    bool redundant_followers = true;
    for (HloInstruction* user : instruction->users()) {
      if (!is_redundant(user)) {
        redundant_followers = false;
        break;
      }
    }
    if (redundant_followers) {
      VLOG(2) << "Skipped: instruction already followed by equivalent"
                 " reduce-precision instructions";
      return false;
    }
  }

  HloInstruction* reduced = instruction->parent()->AddInstruction(
      HloInstruction::CreateReducePrecision(instruction->shape(), instruction,
                                            exponent_bits_, mantissa_bits_));
  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced));
  return true;
}

StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
    const std::vector<HloInstruction*>& instructions) {
  bool computation_changed = false;
  for (auto instruction : instructions) {
    VLOG(2) << "Adding reduce-precision operation to inputs of instruction: "
            << instruction->ToString();
    for (int64 i = 0; i < instruction->operand_count(); i++) {
      HloInstruction* operand = instruction->mutable_operand(i);
      VLOG(2) << "Adding to operand " << i << ": " << operand;

      if (!is_valid_shape(operand->shape())) {
        VLOG(2) << "Skipped: value is not of type F32";
        continue;
      }

      if (is_redundant(operand)) {
        VLOG(2) << "Skipped: operand is already an equivalent reduce-precision"
                   " instruction";
        continue;
      }

      if (instruction->IsInputFusion() || instruction->IsLoopFusion()) {
        // Insert the reduce-precision operation inside the fusion computation,
        // after the corresponding parameter instruction.
        TF_ASSIGN_OR_RETURN(
            bool instruction_changed,
            insert_after(instruction->fused_instructions_computation()
                             ->parameter_instruction(i)));
        computation_changed |= instruction_changed;
      } else {
        // Look for an existing reduce-precision operation on the operand. (We
        // need to be careful not to create a loop, though!)
        HloInstruction* reduced = nullptr;
        for (auto& user : operand->users()) {
          if (user != instruction &&
              user->opcode() == HloOpcode::kReducePrecision &&
              user->exponent_bits() == exponent_bits_ &&
              user->mantissa_bits() == mantissa_bits_) {
            reduced = user;
            break;
          }
        }
        // If there wasn't an existing reduce-precision operation, create one.
        if (!reduced) {
          reduced = instruction->parent()->AddInstruction(
              HloInstruction::CreateReducePrecision(
                  operand->shape(), operand, exponent_bits_, mantissa_bits_));
        }
        // Insert the reduce-precision operation before the operand.
        TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced));
        computation_changed = true;
      }
    }
  }

  return computation_changed;
}

StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
    const std::vector<HloInstruction*>& instructions) {
  bool computation_changed = false;
  for (const auto& instruction : instructions) {
    VLOG(2) << "Adding reduce-precision operation to output of instruction: "
            << instruction->ToString();

    if (!is_valid_shape(instruction->shape())) {
      VLOG(2) << "Skipped: value is not of type F32";
      continue;
    }

    if (instruction->IsLoopFusion() || instruction->IsOutputFusion()) {
      // Insert the reduce-precision operation as the last operation inside
      // the fusion computation.
      HloInstruction* fusion_root = instruction->fused_expression_root();
      VLOG(2) << "Inserting new operation after existing fusion root: "
              << fusion_root->ToString();

      TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root));
      computation_changed |= instruction_changed;
    } else {
      // Insert the reduce-precision operation after the instruction.
      TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction));
      computation_changed |= instruction_changed;
    }
  }

  return computation_changed;
}

StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
  bool changed = false;
  VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();

  for (auto* computation : module->MakeNonfusionComputations()) {
    StatusOr<bool> computation_changed;
    switch (location_) {
      case HloReducePrecisionOptions::OP_INPUTS:
      case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
        computation_changed = ReducePrecisionInsertion::insert_on_inputs(
            instructions_to_modify(computation));
        break;

      case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
      case HloReducePrecisionOptions::OP_OUTPUTS:
      case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
        computation_changed = ReducePrecisionInsertion::insert_on_outputs(
            instructions_to_modify(computation));
        break;
      default:
        break;
    }
    TF_RETURN_IF_ERROR(computation_changed.status());

    if (computation_changed.ValueOrDie()) {
      changed = true;
      VLOG(3) << "Computation after reduce-precision insertion:";
      XLA_VLOG_LINES(3, computation->ToString());
    } else {
      VLOG(3) << "Computation " << computation->name() << " unchanged";
    }
  }

  return changed;
}

ReducePrecisionInsertion::InstructionFilterFunction
ReducePrecisionInsertion::make_filter_function(
    const HloReducePrecisionOptions& reduce_precision_options) {
  // Implement the filter function with a lookup table.
  std::vector<bool> opcode_filter(HloOpcodeCount(), false);
  for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
    opcode_filter[opcode] = true;
  }
  if (reduce_precision_options.opname_substrings_to_suffix_size() == 0) {
    return [opcode_filter](const HloInstruction* instruction) {
      return opcode_filter[static_cast<unsigned int>(instruction->opcode())];
    };
  } else {
    std::vector<string> opname_substrings;
    for (const auto& substring :
         reduce_precision_options.opname_substrings_to_suffix()) {
      opname_substrings.push_back(substring);
    }
    return [opcode_filter,
            opname_substrings](const HloInstruction* instruction) {
      if (!opcode_filter[static_cast<unsigned int>(instruction->opcode())]) {
        return false;
      }
      const auto& opname = instruction->metadata().op_name();
      for (const auto& substring : opname_substrings) {
        if (opname.find(substring) != string::npos) {
          return true;
        }
      }
      return false;
    };
  }
}

HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
    const HloReducePrecisionOptions::Location location,
    const int exponent_bits, const int mantissa_bits,
    const std::function<bool(HloOpcode)>& opcode_filter_function,
    const std::vector<string>& opname_substring_list) {
  HloReducePrecisionOptions options;
  options.set_location(location);
  options.set_exponent_bits(exponent_bits);
  options.set_mantissa_bits(mantissa_bits);
  for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) {
    if (opcode_filter_function(static_cast<HloOpcode>(opcode))) {
      options.add_opcodes_to_suffix(opcode);
    }
  }
  for (auto& string : opname_substring_list) {
    options.add_opname_substrings_to_suffix(string);
  }
  return options;
}

bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline,
                                         const DebugOptions& debug_options,
                                         const PassTiming pass_timing) {
  bool passes_added = false;
  for (const auto& pass_options :
       debug_options.hlo_reduce_precision_options()) {
    bool add_pass;
    switch (pass_options.location()) {
      case HloReducePrecisionOptions::OP_INPUTS:
      case HloReducePrecisionOptions::OP_OUTPUTS:
        add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION;
        break;
      case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
      case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
      case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
        add_pass = pass_timing == PassTiming::AFTER_FUSION;
        break;
      default:
        add_pass = false;
    }
    if (add_pass) {
      pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
      passes_added = true;
    }
  }
  return passes_added;
}

}  // namespace xla
@@ -1,146 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {

// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
class ReducePrecisionInsertion : public HloModulePass {
  using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;

 public:
  // The exponent_bits and mantissa_bits arguments specify the parameters of
  // the instructions to insert. The instructions will be inserted after each
  // instruction with an opcode for which the instruction_filter_function
  // function returns true and the output type is F32.
  explicit ReducePrecisionInsertion(
      const int exponent_bits, const int mantissa_bits,
      const HloReducePrecisionOptions::Location location,
      const InstructionFilterFunction& instruction_filter_function)
      : exponent_bits_(exponent_bits),
        mantissa_bits_(mantissa_bits),
        location_(location),
        instruction_filter_function_(instruction_filter_function) {}

  // Version of the constructor that takes an HloReducePrecisionOptions proto
  // rather than explicitly-enumerated parameters, for convenience when
  // creating passes based on DebugOptions.
  explicit ReducePrecisionInsertion(
      const HloReducePrecisionOptions& reduce_precision_options)
      : exponent_bits_(reduce_precision_options.exponent_bits()),
        mantissa_bits_(reduce_precision_options.mantissa_bits()),
        location_(reduce_precision_options.location()),
        instruction_filter_function_(
            make_filter_function(reduce_precision_options)) {}

  ~ReducePrecisionInsertion() override {}

  absl::string_view name() const override {
    return "reduce-precision-insertion";
  }

  // Run the pass on the given module. Returns whether the module was changed
  // (reduce-precision instructions were inserted).
  StatusOr<bool> Run(HloModule* module) override;

  // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
  // representation and InstructionFilterFunction functions.
  static InstructionFilterFunction make_filter_function(
      const HloReducePrecisionOptions& reduce_precision_options);
  static HloReducePrecisionOptions make_options_proto(
      const HloReducePrecisionOptions::Location location,
      const int exponent_bits, const int mantissa_bits,
      const std::function<bool(HloOpcode)>& opcode_filter_function,
      const std::vector<string>& opname_substring_list = {});

  // Enumeration to control which passes should be added.
  enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };

  // Add ReducePrecisionInsertion passes to an HloPassPipeline based on the
  // list of HloReducePrecisionOptions in a DebugOptions proto. Returns true
  // if any passes were added.
  static bool AddPasses(HloPassPipeline* pipeline,
                        const DebugOptions& debug_options,
                        const PassTiming pass_timing);

 private:
  // Select the instructions that should have reduce-precision operations
  // attached to them.
  std::vector<HloInstruction*> instructions_to_modify(
      const HloComputation* computation);

  // Insert a reduce-precision operation into the graph on the output of the
  // given instruction.
  StatusOr<bool> insert_after(HloInstruction* instruction);

  // Insert reduce-precision operations into the graph on the inputs of the
  // given instructions. (For fusion instructions, the operations will be
  // inserted inside the fusion computation, on the outputs of the relevant
  // input parameters.)
  StatusOr<bool> insert_on_inputs(
      const std::vector<HloInstruction*>& instructions);

  // Insert reduce-precision operations into the graph on the outputs of the
  // given instructions. (For fusion instructions, the operations will be
  // inserted inside the fusion computation as a new root.)
  StatusOr<bool> insert_on_outputs(
      const std::vector<HloInstruction*>& instructions);

  // Is this shape valid for inserting a reduce-precision operation?
  bool is_valid_shape(const Shape& shape) {
    // For now, ReducePrecision is only implemented for F32 arrays, so this
    // ignores instructions that produce other data. In particular, this
    // currently ignores instructions producing tuples, even if those tuples
    // contain F32 arrays inside them. The assumption is that in most cases
    // equivalent behavior can be obtained by adding ReducePrecision
    // instructions after the instructions that pull the F32 arrays out of
    // the tuples.
    return shape.element_type() == PrimitiveType::F32;
  }

  // Is this instruction one such that following or preceding it with a new
  // reduce-precision operation will be redundant?
  bool is_redundant(const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kReducePrecision &&
           instruction->exponent_bits() <= exponent_bits_ &&
           instruction->mantissa_bits() <= mantissa_bits_;
  }

  // Parameters for the precision reduction to be added.
  const int exponent_bits_;
  const int mantissa_bits_;

  // Pass "timing" parameter. This also controls aspects of how the pass
  // selects locations to insert instructions.
  const HloReducePrecisionOptions::Location location_;

  // User-provided function to determine whether a given instruction should
  // have a reduce-precision instruction inserted in its output stream.
  const InstructionFilterFunction instruction_filter_function_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
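For a sense of how the deleted pass was driven directly, a minimal usage sketch; it mirrors the InsertOps helper in the deleted test file below, and the 5/10 arguments are the exponent and mantissa bit counts:

    ReducePrecisionInsertion pass(
        /*exponent_bits=*/5, /*mantissa_bits=*/10,
        HloReducePrecisionOptions::OP_OUTPUTS,
        [](const HloInstruction* instruction) {
          return instruction->opcode() == HloOpcode::kCos;
        });
    StatusOr<bool> changed = pass.Run(module);  // true if any op was inserted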
@ -1,578 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class ReducePrecisionInsertionTest : public HloTestBase {
|
||||
protected:
|
||||
bool InsertOps(HloModule* module,
|
||||
const HloReducePrecisionOptions::Location location,
|
||||
const std::function<bool(const HloInstruction*)>& filter) {
|
||||
ReducePrecisionInsertion op_insertion(5, 10, location, filter);
|
||||
StatusOr<bool> result = op_insertion.Run(module);
|
||||
EXPECT_IS_OK(result.status());
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with parameter feeding a binary add function.
|
||||
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
|
||||
HloInstruction* c = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), c);
|
||||
EXPECT_EQ(c->operand(0), a);
|
||||
EXPECT_EQ(c->operand(1), b);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kAdd;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), c);
|
||||
EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
|
||||
EXPECT_THAT(c->operand(1), op::ReducePrecision(b));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
|
||||
EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() ==
|
||||
HloOpcode::kParameter;
|
||||
}));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with parameter feeding a binary add function.
|
||||
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
HloInstruction* c = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));
|
||||
HloInstruction* d = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), d);
|
||||
EXPECT_EQ(b->operand(0), a);
|
||||
EXPECT_EQ(c->operand(0), a);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos ||
|
||||
instruction->opcode() == HloOpcode::kSin;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops. In particular, we want to confirm
|
||||
// that the reduced-precision operation added for the input to b is re-used
|
||||
// for the input to c.
|
||||
EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
|
||||
EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
|
||||
EXPECT_EQ(b->operand(0), c->operand(0));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a simple graph with a parameter feeding a unary cosine function.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision(b));
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
|
||||
// Create a graph with two parameters feeding into unary cosine functions,
|
||||
// and the output of those feeds into an add function. Feeding the outputs
|
||||
// from the suffixed cosine functions into a binary add function allows us to
|
||||
// confirm that the separate operand streams are not crossed when the new
|
||||
// instructions are inserted.
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* a_cos = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
|
||||
|
||||
HloInstruction* b =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
|
||||
HloInstruction* b_cos = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));
|
||||
|
||||
HloInstruction* c = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_EQ(c->operand(0), a_cos);
|
||||
EXPECT_EQ(c->operand(1), b_cos);
|
||||
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCos;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(c->operand(0), op::ReducePrecision());
|
||||
EXPECT_EQ(c->operand(0)->operand(0), a_cos);
|
||||
EXPECT_THAT(c->operand(1), op::ReducePrecision());
|
||||
EXPECT_EQ(c->operand(1)->operand(0), b_cos);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(S32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since none of the instructions produce F32 data, this should not change
|
||||
// the graph.
|
||||
EXPECT_FALSE(
|
||||
InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) { return true; }));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since none of the instructions match the should_reduce_output_precision
|
||||
// function, this should not change the graph.
|
||||
EXPECT_FALSE(
|
||||
InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) { return false; }));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateReducePrecision(shape, a, 8, 23));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected state before adding ops.
|
||||
EXPECT_EQ(computation->root_instruction(), b);
|
||||
|
||||
// This should insert a new ReducePrecision after the existing one, but
|
||||
// should not then recurse by adding another after the just-inserted one.
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() ==
|
||||
HloOpcode::kReducePrecision;
|
||||
}));
|
||||
|
||||
// Confirm expected graph after adding ops.
|
||||
EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
|
||||
EXPECT_EQ(computation->root_instruction()->operand(0), b);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateReducePrecision(shape, x, 5, 10));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since the new reduce-precision operation would be redundant, this
|
||||
// should not change the graph.
|
||||
EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() ==
|
||||
HloOpcode::kParameter;
|
||||
}));
|
||||
|
||||
// Confirm that graph has not changed.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
}
|
||||
|
||||
TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4});
|
||||
HloInstruction* x =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
HloInstruction* y = builder.AddInstruction(
|
||||
HloInstruction::CreateReducePrecision(shape, x, 8, 23));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Confirm expected graph before adding ops.
|
||||
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
|
||||
// Since the new reduce-precision operation is not the same as the existing
|
||||
// one, this should add a new one.
|
||||
EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kParameter;
|
||||
}));
|
||||
|
||||
// Confirm that graph is as expected.
|
||||
EXPECT_EQ(computation->root_instruction(), y);
|
||||
EXPECT_THAT(y->operand(0), op::ReducePrecision(x));
|
||||
}

TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually fuse the kCos operation into a fusion operation.
  HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kLoop, y));
  EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
  EXPECT_IS_OK(computation->RemoveInstruction(y));

  // Confirm expected graph before adding reduce-precision ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  HloInstruction* y_fused = z->fused_expression_root();
  EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);

  // The ReducePrecisionInsertion pass should not see inside the fusion
  // operation, so this should not change the graph.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS,
                         [](const HloInstruction* instruction) {
                           return instruction->opcode() == HloOpcode::kCos;
                         }));

  // Confirm that graph has not changed.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  EXPECT_EQ(z->fused_expression_root(), y_fused);
}

TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually fuse the kCos operation into a fusion operation.
  HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kLoop, y));
  EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
  EXPECT_IS_OK(computation->RemoveInstruction(y));

  // Confirm expected graph before adding reduce-precision ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  HloInstruction* y_fused = z->fused_expression_root();
  EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);

  // This should see that the fusion computation contains a kCos operation,
  // and insert a new reduce-precision node at its input.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
                        [](const HloInstruction* instruction) {
                          return instruction->opcode() == HloOpcode::kCos;
                        }));

  // This should refuse to insert a second reduce-precision operation, as
  // it would be redundant with the first.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
                         [](const HloInstruction* instruction) {
                           return instruction->opcode() == HloOpcode::kCos;
                         }));

  // Confirm that the top-level computation still only contains the fusion
  // instruction, but that the fused computation now has a reduce-precision
  // instruction inserted after its parameter instruction.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  EXPECT_THAT(z->fused_expression_root(), y_fused);
  EXPECT_THAT(y_fused->operand(0), op::ReducePrecision(op::Parameter()));
}

TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually fuse the kCos operation into a fusion operation.
  HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kLoop, y));
  EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
  EXPECT_IS_OK(computation->RemoveInstruction(y));

  // Confirm expected graph before adding reduce-precision ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  HloInstruction* y_fused = z->fused_expression_root();
  EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);

  // This should see that the fusion computation contains a kCos operation,
  // and insert a new reduce-precision node at its root.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
                        [](const HloInstruction* instruction) {
                          return instruction->opcode() == HloOpcode::kCos;
                        }));

  // This should refuse to insert a second reduce-precision operation, as
  // it would be redundant with the first.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
                         [](const HloInstruction* instruction) {
                           return instruction->opcode() == HloOpcode::kCos;
                         }));

  // Confirm that the top-level computation still only contains the fusion
  // instruction, but that the fused computation now has a reduce-precision
  // instruction inserted as its root.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  EXPECT_THAT(z->fused_expression_root(), op::ReducePrecision(y_fused));
}
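
Read together, the two fusion tests pin down where the inserted op lands; the following rough picture is reconstructed from the expectations above rather than from actual pass output.

// OpGetsInsertedInHeadOfFusionNode (FUSION_INPUTS_BY_CONTENT):
//   fused computation: parameter -> reduce-precision -> cos
// OpGetsInsertedInTailOfFusionNode (FUSION_OUTPUTS_BY_CONTENT):
//   fused computation: parameter -> cos -> reduce-precision
// In both cases the entry computation still contains only the fusion
// instruction, and a second run of the same pass is a no-op.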

TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* a =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));

  auto options_proto = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });

  auto filter_function =
      ReducePrecisionInsertion::make_filter_function(options_proto);

  EXPECT_TRUE(filter_function(b));
  EXPECT_FALSE(filter_function(c));
}

TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionWithSubstrings) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* a =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));

  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
  OpMetadata b_metadata;
  b_metadata.set_op_name("FlowTensor/foom");
  b->set_metadata(b_metadata);

  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
  OpMetadata c_metadata;
  c_metadata.set_op_name("FlowTensor/barn");
  c->set_metadata(c_metadata);

  auto options_proto = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; },
      {"foo", "baz"});

  auto filter_function =
      ReducePrecisionInsertion::make_filter_function(options_proto);

  EXPECT_TRUE(filter_function(b));
  EXPECT_FALSE(filter_function(c));
}
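
The two filter-function tests imply a two-stage matching rule. The sketch below restates that rule as a hypothetical Matches helper, assuming absl::StrContains for the substring check; it mirrors the test expectations, not the lambda make_filter_function actually returned.

#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla.pb.h"

// Hypothetical restatement of the filter semantics exercised above.
bool Matches(const HloInstruction* inst,
             const HloReducePrecisionOptions& options) {
  // Stage 1: the instruction's opcode must be in the suffixed set.
  bool opcode_matches = false;
  for (uint32_t opcode : options.opcodes_to_suffix()) {
    if (static_cast<HloOpcode>(opcode) == inst->opcode()) {
      opcode_matches = true;
      break;
    }
  }
  if (!opcode_matches) return false;
  // Stage 2: if name substrings were given, the metadata op_name must
  // contain at least one of them. "foo" matches "FlowTensor/foom", while
  // neither "foo" nor "baz" matches "FlowTensor/barn".
  if (options.opname_substrings_to_suffix().empty()) return true;
  for (const std::string& substring : options.opname_substrings_to_suffix()) {
    if (absl::StrContains(inst->metadata().op_name(), substring)) return true;
  }
  return false;
}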

}  // namespace xla
@ -910,7 +910,6 @@ xla_test(
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/service:reduce_precision_insertion",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -532,138 +531,5 @@ void ReducedPrecisionAccuracyTest::DoIt(
INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
                        ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));

// Tests to confirm that the compiler optimization functions add the expected
// ReducePrecisionInsertion passes.
class ReducePrecisionInsertionTest : public ClientLibraryTestBase {};

// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
           DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
  XlaBuilder builder(TestName());

  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");

  // Abs doesn't affect resolution.
  auto abs = Abs(a);

  // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the
  // reduce-precision operation showed up in the correct place in the
  // graph.
  Log(abs);

  // Insert precision-reduction after the Abs(x) operation, rounding that
  // result to exactly 1.0f.
  auto reduce_precision_pass = execution_options_.mutable_debug_options()
                                   ->add_hlo_reduce_precision_options();
  *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });

  ComputeAndCompareR1<float>(&builder, {0.0f}, {a_data.get()});
}
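
To make the expected {0.0f} concrete, here is a worked check of the rounding arithmetic, assuming round-to-nearest behavior for the reduce-precision op.

// With mantissa_bits = 10, the spacing between representable values near
// 1.0 is 2^-10, roughly 0.001, so 1.00001f rounds to exactly 1.0f under the
// (5, 10) reduced precision, and log(1.0f) == 0.0f; hence the expected
// {0.0f}. Without the inserted op the result would be approximately
// log(1.00001f), about 1e-5.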

// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
           DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
  XlaBuilder builder(TestName());

  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");

  // These two operations should be fused by any reasonable backend.
  auto abs = Abs(a);
  Neg(abs);

  // Add a pass after operation fusion, suffixing kAbs operations. This
  // should not see into the fusion nodes and thus should not affect the
  // result.
  auto reduce_precision_pass = execution_options_.mutable_debug_options()
                                   ->add_hlo_reduce_precision_options();
  *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });

  ComputeAndCompareR1<float>(&builder, {-1.00001f, -1.00001f}, {a_data.get()});
}

// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
           DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
  XlaBuilder builder(TestName());

  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");

  // These two operations should be fused by any reasonable backend.
  auto abs = Abs(a);
  Neg(abs);

  // Add a pass after operation fusion, suffixing kFusion operations.
  auto reduce_precision_pass = execution_options_.mutable_debug_options()
                                   ->add_hlo_reduce_precision_options();
  *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; });

  ComputeAndCompareR1<float>(&builder, {-1.0f, -1.0f}, {a_data.get()});
}

// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
           DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
  XlaBuilder builder(TestName());

  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");

  // These two operations should be fused by any reasonable backend.
  auto abs = Abs(a);
  Neg(abs);

  // Add a pass suffixing fusion nodes containing kCos operations. This
  // should have no effect.
  auto reduce_precision_pass = execution_options_.mutable_debug_options()
                                   ->add_hlo_reduce_precision_options();
  *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });

  ComputeAndCompareR1<float>(&builder, {-1.00001f}, {a_data.get()});
}

// The interpreter has no fusion pass, so skip this test.
XLA_TEST_F(ReducePrecisionInsertionTest,
           DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
  XlaBuilder builder(TestName());

  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001, 1.00001});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");

  // These two operations should be fused by any reasonable backend.
  auto abs = Abs(a);
  Neg(abs);

  // Add a pass suffixing fusion nodes containing kAbs operations. This
  // should see the kAbs operation within the above fusion node.
  auto reduce_precision_pass = execution_options_.mutable_debug_options()
                                   ->add_hlo_reduce_precision_options();
  *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });

  ComputeAndCompareR1<float>(&builder, {-1.0f, -1.0f}, {a_data.get()});
}
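
The contrast between the last two tests comes down to whether the fusion node contains a matching opcode; a rough summary of the outcomes, drawn from the expectations above:

// ReducePrecisionSkippedFusionContains: the filter matches kCos, the fused
// {abs, neg} computation contains no kCos, so nothing is inserted and the
// output stays at -1.00001f.
// ReducePrecisionAddedFusionContains: the filter matches kAbs, which the
// fusion does contain, so a reduce-precision op is appended to the fusion
// output and the result rounds to -1.0f.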

}  // namespace
}  // namespace xla
@ -20,44 +20,6 @@ package xla;
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Options for the HLO insert-reduce-precision-operations pass.
message HloReducePrecisionOptions {
  // Where and when the reduce-precision operations will be added.
  enum Location {
    // Add reduce-precision operations to the inputs of selected instructions.
    // This is done before any optimization occurs.
    OP_INPUTS = 0;
    // Add reduce-precision operations to the outputs of selected instructions.
    // This is done before any optimization occurs.
    OP_OUTPUTS = 1;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any selected instructions that have not been fused into
    // fusion instructions.
    UNFUSED_OP_OUTPUTS = 2;
    // After operation-fusion occurs, add reduce-precision operations to the
    // inputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_INPUTS_BY_CONTENT = 3;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_OUTPUTS_BY_CONTENT = 4;
  }
  Location location = 1;

  // Exponent and mantissa bit counts for the reduced precision.
  uint32 exponent_bits = 2;
  uint32 mantissa_bits = 3;

  // Operations matching these opcodes should be suffixed with reduce-precision
  // operations.
  repeated uint32 opcodes_to_suffix = 4;

  // Operations with names containing these substrings should be suffixed with
  // reduce-precision operations.
  repeated string opname_substrings_to_suffix = 5;
}
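
For reference, a minimal sketch of building this message from C++ via the standard proto3-generated accessors for the fields defined above; the MakeExampleOptions wrapper and its chosen values are illustrative only, mirroring the (5, 10) settings used in the tests.

#include <cstdint>

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"

// Sketch: an options proto selecting OP_OUTPUTS with 5 exponent bits and
// 10 mantissa bits, matching kAbs ops whose names contain "FlowTensor".
xla::HloReducePrecisionOptions MakeExampleOptions() {
  xla::HloReducePrecisionOptions options;
  options.set_location(xla::HloReducePrecisionOptions::OP_OUTPUTS);
  options.set_exponent_bits(5);
  options.set_mantissa_bits(10);
  options.add_opcodes_to_suffix(static_cast<uint32_t>(xla::HloOpcode::kAbs));
  options.add_opname_substrings_to_suffix("FlowTensor");
  return options;
}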

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
@ -122,10 +84,7 @@ message DebugOptions {
  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  // Options for inserting reduce-precision operations for numerical
  // experimentation. This is a repeated field, as we may want to have
  // multiple passes with different parameters.
  repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;
  reserved 80;  // Was hlo_reduce_precision_options

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the