From 99fc31e82fc8d5a5506a8f2de4dd7eb5f7f160e2 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Mon, 3 Aug 2020 12:35:28 -0700
Subject: [PATCH] Add a module config option to enable hlo deduplication.

PiperOrigin-RevId: 324660155
Change-Id: Ic7aac0daf851bb93b4f6c24e56b20234200efdbc
---
 .../compiler/xla/client/executable_build_options.cc   |  6 ++++++
 .../compiler/xla/client/executable_build_options.h    |  4 ++++
 .../compiler/xla/service/compile_only_service.cc      |  1 +
 tensorflow/compiler/xla/service/compiler.h            |  1 +
 tensorflow/compiler/xla/service/hlo_module.cc         |  1 +
 tensorflow/compiler/xla/service/hlo_module_config.h   | 11 +++++++++++
 tensorflow/compiler/xla/service/local_service.cc      |  1 +
 tensorflow/compiler/xla/service/service.cc            |  1 +
 tensorflow/compiler/xla/xla.proto                     |  4 ++++
 9 files changed, 30 insertions(+)

diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 404f9eb7519..f39a3e79fe5 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -76,6 +76,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning(
   return *this;
 }
 
+ExecutableBuildOptions& ExecutableBuildOptions::set_deduplicate_hlo(
+    bool deduplicate_hlo) {
+  deduplicate_hlo_ = deduplicate_hlo;
+  return *this;
+}
+
 ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
     const DeviceAssignment& device_assignment) {
   device_assignment_ = device_assignment;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 9a7fdd974b1..d034eaa7fd6 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -82,6 +82,9 @@ class ExecutableBuildOptions {
   bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
   ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning);
 
+  bool deduplicate_hlo() const { return deduplicate_hlo_; }
+  ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo);
+
   // If set, this specifies a static device assignment for the computation.
   // Otherwise, the computation will be compiled generically and can be run with
   // any device assignment compatible with the computation's replica and
@@ -110,6 +113,7 @@ class ExecutableBuildOptions {
   int num_replicas_ = 1;
   int num_partitions_ = 1;
   bool use_spmd_partitioning_ = false;
+  bool deduplicate_hlo_ = false;
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
 };
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index ce9c8a4ea62..f8e4f591a5d 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -92,6 +92,7 @@ CompileOnlyService::CompileAheadOfTime(
         execution_options.mutable_device_assignment()));
   }
   execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning());
+  execution_options.set_deduplicate_hlo(options.deduplicate_hlo());
   for (const AotXlaComputationInstance& instance : computations) {
     TF_RET_CHECK(instance.computation.has_host_program_shape());
     *execution_options.mutable_shape_with_output_layout() =
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 57b24e372e6..312a068ba65 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -77,6 +77,7 @@ class AotCompilationOptions {
   virtual int64 replica_count() const { return 0; }
   virtual int64 num_cores() const { return 0; }
   virtual bool use_spmd_partitioning() const { return false; }
+  virtual bool deduplicate_hlo() const { return false; }
 
   // Optional allocator that may be used for allocating temp space on the device
   // during compilation.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 308b8e8f095..4a67c1d2146 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -443,6 +443,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
     }
     module_config.set_use_spmd_partitioning(
         execution_options->use_spmd_partitioning());
+    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
     if (execution_options->has_device_assignment()) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                           DeviceAssignment::Deserialize(
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 7ab0f24d06e..ae0a8aae838 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -138,6 +138,13 @@ class HloModuleConfig {
   }
   bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
 
+  // If enabled, deduplicates equivalent HLO instructions into function calls
+  // to reduce code size.
+  void set_deduplicate_hlo(bool deduplicate_hlo) {
+    deduplicate_hlo_ = deduplicate_hlo;
+  }
+  bool deduplicate_hlo() const { return deduplicate_hlo_; }
+
   // Return a string which unambiguously represents all the fields of this data
   // structure. Used for generating a cache key for storing the compiled
   // executable.
@@ -246,6 +253,10 @@ class HloModuleConfig {
   // needs to partition the module.
   bool use_spmd_partitioning_ = false;
 
+  // If enabled, deduplicates equivalent HLO instructions into function calls
+  // to reduce code size.
+  bool deduplicate_hlo_ = false;
+
   // The target maximum parallelism at which to partition HLOs for parallel
   // execution on the CPU backend.
   int64 intra_op_parallelism_threads_ = -1;
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index c80646e0c70..5def5bbe9db 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -114,6 +114,7 @@ ExecutionOptions CreateExecutionOptions(
   execution_options.set_num_partitions(build_options.num_partitions());
   execution_options.set_use_spmd_partitioning(
       build_options.use_spmd_partitioning());
+  execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
   if (build_options.has_device_assignment()) {
     TF_CHECK_OK(build_options.device_assignment().Serialize(
         execution_options.mutable_device_assignment()));
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 2ed5e709d81..4437ec3d452 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -315,6 +315,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     }
     config->set_use_spmd_partitioning(
         execution_options->use_spmd_partitioning());
+    config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
     config->set_seed(execution_options->seed());
     config->set_launch_id(execution_options->launch_id());
     config->set_debug_options(execution_options->debug_options());
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 6b9917eac53..1cf30b10373 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -349,6 +349,10 @@ message ExecutionOptions {
   // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
   // num_partitions > 1 and XLA is requested to partition the input program.
   bool use_spmd_partitioning = 11;
+
+  // If set, deduplicates equivalent HLO into function calls to reduce binary
+  // size. Only works on TPU.
+  bool deduplicate_hlo = 12;
 }
 
 message GetDeviceHandlesRequest {