[XLA:Python] Add Python bindings for HloCostAnalysis.

PiperOrigin-RevId: 335543233
Change-Id: Ibcf508d3e787b7a01bb8d16d797dad4f6b984ce0
This commit is contained in:
Peter Hawkins 2020-10-05 18:20:33 -07:00 committed by TensorFlower Gardener
parent fc86fb1743
commit 4f0d3bd4d2
5 changed files with 25 additions and 1 deletions

View File

@ -141,6 +141,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",

View File

@ -90,6 +90,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@ -282,6 +283,11 @@ StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
return parameters_to_donate;
}
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
}
namespace {
// Ensures that it is safe to deallocate any buffers that have been enqueued in

View File

@ -195,6 +195,9 @@ class PjRtClient {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(

View File

@ -75,7 +75,6 @@ namespace {
namespace py = pybind11;
struct Uniquer {
absl::Mutex mu;
NameUniquer name_uniquer TF_GUARDED_BY(mu);
@ -820,6 +819,14 @@ PYBIND11_MODULE(xla_extension, m) {
hlo_module.config().debug_options(),
RenderedGraphFormat::kDot);
});
m.def(
"hlo_module_cost_analysis",
[](PyClient* client,
const HloModule& module) -> StatusOr<std::map<string, float>> {
auto analysis = client->pjrt_client()->GetHloCostAnalysis();
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
return analysis->properties();
});
py::class_<XlaOp> xla_op_class(m, "XlaOp");

View File

@ -173,6 +173,13 @@ def TestFactory(xla_backend, cloud_tpu=False):
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
self.assertIn("fusion", hlo_text)
@unittest.skipIf(cloud_tpu, "not implemented")
def testFlopEstimate(self):
computation = self.ExampleComputation()
properties = xla_client._xla.hlo_module_cost_analysis(
self.backend, computation.as_hlo_module())
self.assertEqual(properties["flops"], 8.0)
tests.append(ComputationPrinting)
class ComputationHashTest(absltest.TestCase):