Rename "Inliner" to "MapInliner".
PiperOrigin-RevId: 215801897
This commit is contained in:
parent
a2e48d849f
commit
b01ea7a51c
@ -1841,42 +1841,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inliner",
|
||||
srcs = ["inliner.cc"],
|
||||
hdrs = ["inliner.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "inliner_test",
|
||||
srcs = ["inliner_test.cc"],
|
||||
deps = [
|
||||
":cpu_plugin",
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":inliner",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "computation_placer",
|
||||
srcs = ["computation_placer.cc"],
|
||||
@ -3492,6 +3456,39 @@ cc_library(
|
||||
deps = ["//tensorflow/core:lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_inliner",
|
||||
srcs = ["map_inliner.cc"],
|
||||
hdrs = ["map_inliner.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "map_inliner_test",
|
||||
srcs = ["map_inliner_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":map_inliner",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_casting_utils_test",
|
||||
srcs = ["hlo_casting_utils_test.cc"],
|
||||
|
@ -94,6 +94,7 @@ cc_library(
|
||||
":target_machine_features",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:scatter_expander",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
@ -127,7 +128,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:indexed_array_analysis",
|
||||
"//tensorflow/compiler/xla/service:inliner",
|
||||
"//tensorflow/compiler/xla/service:llvm_compiler",
|
||||
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
|
@ -86,8 +86,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/scatter_expander.h"
|
||||
@ -249,7 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
&pipeline, module->config().debug_options(),
|
||||
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
|
||||
|
||||
pipeline.AddPass<Inliner>();
|
||||
pipeline.AddPass<MapInliner>();
|
||||
|
||||
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
|
||||
// pass.
|
||||
|
@ -45,8 +45,8 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:inliner",
|
||||
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:while_loop_simplifier",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -28,9 +28,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -32,10 +32,10 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// InlinerVisitor traverses the HLO computation and inlines maps.
|
||||
class InlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
// MapInlinerVisitor traverses the HLO computation and inlines maps.
|
||||
class MapInlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit InlinerVisitor(HloComputation* computation)
|
||||
explicit MapInlinerVisitor(HloComputation* computation)
|
||||
: computation_(computation) {}
|
||||
|
||||
// Default visitor action is to do nothing and return OK.
|
||||
@ -49,24 +49,23 @@ class InlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
StatusOr<bool> Run(HloComputation* computation);
|
||||
|
||||
private:
|
||||
// Current HloComputation instance the InlinerVisitor is traversing.
|
||||
// Current HloComputation instance the MapInlinerVisitor is traversing.
|
||||
HloComputation* computation_;
|
||||
|
||||
// Whether algebraic simplification has occurred.
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
|
||||
StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
|
||||
changed_ = false;
|
||||
computation_ = computation;
|
||||
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
|
||||
return changed_;
|
||||
}
|
||||
|
||||
Status InlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
HloComputation* function = map->to_apply();
|
||||
HloInstruction& root = *function->root_instruction();
|
||||
// TODO(b/29249531): Add DCE pass to remove unused HloComputations.
|
||||
// Only inlining functions that are simply a single operation until a better
|
||||
// profitability model for inlining is defined.
|
||||
if (hlo_query::AllOperandsAreParameters(root)) {
|
||||
@ -112,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> Inliner::Run(HloModule* module) {
|
||||
InlinerVisitor visitor(/*computation=*/nullptr);
|
||||
StatusOr<bool> MapInliner::Run(HloModule* module) {
|
||||
MapInlinerVisitor visitor(/*computation=*/nullptr);
|
||||
bool changed = false;
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
|
@ -13,27 +13,27 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A pass which performs inlining. Which can result, for example, in functions
|
||||
// that were previously being mapped by Map instead directly applied to the
|
||||
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
|
||||
class Inliner : public HloModulePass {
|
||||
// A pass which performs map inlining. This replaces kMap instructions with
|
||||
// their equivalent sequence of array operations. For example:
|
||||
// map({X, Y}, add) -> add(X, Y)).
|
||||
class MapInliner : public HloModulePass {
|
||||
public:
|
||||
~Inliner() override = default;
|
||||
absl::string_view name() const override { return "inline"; }
|
||||
~MapInliner() override = default;
|
||||
absl::string_view name() const override { return "map-inline"; }
|
||||
|
||||
// Run inlining on the given computation. Returns whether the computation was
|
||||
// changed.
|
||||
// Run map inlining on the given computation. Returns whether the computation
|
||||
// was changed.
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using InlinerTest = HloVerifiedTestBase;
|
||||
using MapInlinerTest = HloVerifiedTestBase;
|
||||
|
||||
// Test that `map` with `max` is transformed to `max`
|
||||
TEST_F(InlinerTest, MapMax) {
|
||||
TEST_F(MapInlinerTest, MapMax) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto max_builder = HloComputation::Builder(TestName());
|
||||
@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(max_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
|
||||
op::Maximum(lhs, rhs));
|
||||
@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) {
|
||||
}
|
||||
|
||||
// Test that `constant` function is changed to `broadcast`.
|
||||
TEST_F(InlinerTest, MapConstant) {
|
||||
TEST_F(MapInlinerTest, MapConstant) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto const2_builder = HloComputation::Builder(TestName());
|
||||
@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Broadcast(op::Constant()));
|
||||
@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
|
||||
}
|
||||
|
||||
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
// Note that the parameter ordinals are in the opposite order to their
|
||||
@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(max_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
|
||||
op::Subtract(rhs, lhs));
|
||||
@ -146,7 +146,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
|
||||
}
|
||||
|
||||
TEST_F(InlinerTest, MapParameter) {
|
||||
TEST_F(MapInlinerTest, MapParameter) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto param_builder = HloComputation::Builder(TestName());
|
||||
@ -167,7 +167,7 @@ TEST_F(InlinerTest, MapParameter) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(param_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
|
||||
|
Loading…
x
Reference in New Issue
Block a user