From 4f0d3bd4d29545394c049820362875c0fc943b8f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 5 Oct 2020 18:20:33 -0700 Subject: [PATCH] [XLA:Python] Add Python bindings for HloCostAnalysis. PiperOrigin-RevId: 335543233 Change-Id: Ibcf508d3e787b7a01bb8d16d797dad4f6b984ce0 --- tensorflow/compiler/xla/pjrt/BUILD | 1 + tensorflow/compiler/xla/pjrt/pjrt_client.cc | 6 ++++++ tensorflow/compiler/xla/pjrt/pjrt_client.h | 3 +++ tensorflow/compiler/xla/python/xla.cc | 9 ++++++++- tensorflow/compiler/xla/python/xla_client_test.py | 7 +++++++ 5 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 50fd6a12d66..d81b9a4b84c 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -141,6 +141,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index f3dafd71ba5..0a5761c721b 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -90,6 +90,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -282,6 +283,11 @@ StatusOr<std::vector<int>> PjRtClient::GetParametersThatMustBeDonated( return parameters_to_donate; } +std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() { + return absl::make_unique<HloCostAnalysis>( + client_->backend().compiler()->ShapeSizeBytesFunction()); +} + namespace { // Ensures that it is safe to deallocate any buffers that have been enqueued in diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 39711534f79..1e4169d0d2b 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -195,6 +195,9 @@ class PjRtClient { return absl::optional<std::string>(); } + // Returns a backend-specific HLO cost analysis visitor. 
+ virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis(); + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 06605660b63..482524ffc18 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -75,7 +75,6 @@ namespace { namespace py = pybind11; - struct Uniquer { absl::Mutex mu; NameUniquer name_uniquer TF_GUARDED_BY(mu); @@ -820,6 +819,14 @@ PYBIND11_MODULE(xla_extension, m) { hlo_module.config().debug_options(), RenderedGraphFormat::kDot); }); + m.def( + "hlo_module_cost_analysis", + [](PyClient* client, + const HloModule& module) -> StatusOr<std::map<std::string, float>> { + auto analysis = client->pjrt_client()->GetHloCostAnalysis(); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + return analysis->properties(); + }); py::class_<XlaOp> xla_op_class(m, "XlaOp"); diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 0699f4f0200..3863d8a1481 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -173,6 +173,13 @@ def TestFactory(xla_backend, cloud_tpu=False): self.assertTrue(hlo_text.startswith("HloModule acomputation")) self.assertIn("fusion", hlo_text) + @unittest.skipIf(cloud_tpu, "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + tests.append(ComputationPrinting) class ComputationHashTest(absltest.TestCase):