[XLA:Python] Add Python bindings for HloCostAnalysis.
PiperOrigin-RevId: 335543233 Change-Id: Ibcf508d3e787b7a01bb8d16d797dad4f6b984ce0
This commit is contained in:
parent
fc86fb1743
commit
4f0d3bd4d2
@ -141,6 +141,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
|
@ -90,6 +90,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
@ -282,6 +283,11 @@ StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
|
||||
return parameters_to_donate;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
|
||||
return absl::make_unique<HloCostAnalysis>(
|
||||
client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Ensures that it is safe to deallocate any buffers that have been enqueued in
|
||||
|
@ -195,6 +195,9 @@ class PjRtClient {
|
||||
return absl::optional<std::string>();
|
||||
}
|
||||
|
||||
// Returns a backend-specific HLO cost analysis visitor.
|
||||
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
|
||||
|
||||
protected:
|
||||
friend class PjRtBuffer;
|
||||
virtual void EnqueueCrossHostReceive(
|
||||
|
@ -75,7 +75,6 @@ namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
struct Uniquer {
|
||||
absl::Mutex mu;
|
||||
NameUniquer name_uniquer TF_GUARDED_BY(mu);
|
||||
@ -820,6 +819,14 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
hlo_module.config().debug_options(),
|
||||
RenderedGraphFormat::kDot);
|
||||
});
|
||||
m.def(
|
||||
"hlo_module_cost_analysis",
|
||||
[](PyClient* client,
|
||||
const HloModule& module) -> StatusOr<std::map<string, float>> {
|
||||
auto analysis = client->pjrt_client()->GetHloCostAnalysis();
|
||||
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
|
||||
return analysis->properties();
|
||||
});
|
||||
|
||||
py::class_<XlaOp> xla_op_class(m, "XlaOp");
|
||||
|
||||
|
@ -173,6 +173,13 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
|
||||
self.assertIn("fusion", hlo_text)
|
||||
|
||||
@unittest.skipIf(cloud_tpu, "not implemented")
|
||||
def testFlopEstimate(self):
|
||||
computation = self.ExampleComputation()
|
||||
properties = xla_client._xla.hlo_module_cost_analysis(
|
||||
self.backend, computation.as_hlo_module())
|
||||
self.assertEqual(properties["flops"], 8.0)
|
||||
|
||||
tests.append(ComputationPrinting)
|
||||
|
||||
class ComputationHashTest(absltest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user