diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 1d9baf5de10..535e5b605b4 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -153,6 +153,7 @@ cc_binary(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:computation_tracker",
         "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index 8b96e134897..1f0ca31d6d6 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
 #include "tensorflow/compiler/xla/service/service.h"
 #include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -34,7 +35,7 @@ limitations under the License.
 namespace xla {
 namespace tools {
 
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
   LocalClient* client = ClientLibrary::LocalClientOrDie();
   LocalService* local_service =
       ClientLibrary::GetXlaService(client->platform());
@@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
     }
     Computation computation = computation_status.ConsumeValueOrDie();
 
-    std::unique_ptr<ProgramShape> program_shape =
-        client->GetComputationShape(computation).ConsumeValueOrDie();
+    if (compile) {  // Compile first, then print the executable's HLO module.
+      std::unique_ptr<ProgramShape> program_shape =
+          client->GetComputationShape(computation).ConsumeValueOrDie();
 
-    std::vector<const Shape*> layouts;
-    for (int i = 0; i < program_shape->parameters_size(); ++i) {
-      layouts.push_back(&program_shape->parameters(i));
+      std::vector<const Shape*> layouts;  // Pointers to each parameter shape.
+      for (int i = 0; i < program_shape->parameters_size(); ++i) {
+        layouts.push_back(&program_shape->parameters(i));
+      }
+      StatusOr<std::unique_ptr<Executable>> executable =
+          local_service->CompileExecutable(
+              computation.handle(), layouts, &program_shape->result(),
+              /*device_ordinal=*/0, /*has_hybrid_result=*/true);
+
+      const HloModule& module = executable.ValueOrDie()->module();  // NOTE(review): presumably dies if compilation failed - confirm.
+
+      fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
+              local_service->backend().platform()->Name().c_str(),
+              module.ToString().c_str());
+    } else {  // Not compiling: print the HLO module built by the tracker.
+      const ComputationTracker& tracker = local_service->computation_tracker();
+      UserComputation* user_computation =
+          tracker.Resolve(computation.handle()).ConsumeValueOrDie();
+      VersionedComputationHandle versioned_handle =
+          user_computation->GetVersionedHandle();
+      std::unique_ptr<HloModule> module =
+          tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie();
+
+      fprintf(stdout, "%s\n", module->ToString().c_str());  // No backend header here.
     }
-    StatusOr<std::unique_ptr<Executable>> executable =
-        local_service->CompileExecutable(
-            computation.handle(), layouts, &program_shape->result(),
-            /*device_ordinal=*/0, /*has_hybrid_result=*/true);
-
-    const HloModule& module = executable.ValueOrDie()->module();
-
-    fprintf(stdout, "HLO for %s backend:\n%s\n",
-            local_service->backend().platform()->Name().c_str(),
-            module.ToString().c_str());
   }
 }
 
@@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
 }  // namespace xla
 
 int main(int argc, char** argv) {
-  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  bool compile = false;  // Bound to the "compile" flag; forwarded to RealMain.
+  std::vector<tensorflow::Flag> flag_list = {
+      {"compile", &compile,
+       "If true, compile the computation using the default client before "
+       "dumping the HLO. Otherwise dump the raw (uncompiled) HLO."},
+  };
+  const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);  // NOTE(review): presumably strips recognized flags from argv - confirm.
+  QCHECK(parsed_flags_ok) << "\n" << usage;  // Failure message carries the usage text.
+
+  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+  QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
 
   tensorflow::gtl::ArraySlice<char*> args(argv, argc);
   args.pop_front();  // Pop off the binary name, argv[0]
-  xla::tools::RealMain(args);
+  xla::tools::RealMain(args, compile);
   return 0;
 }