[xla] Expose XlaComputation constructor in the Python API
PiperOrigin-RevId: 304430744 Change-Id: Iaef4f9e749b226c821aafcae440cdfc590828473
This commit is contained in:
parent
eb0998340c
commit
bb354b5f64
@ -1206,6 +1206,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
py::return_value_policy::reference, py::keep_alive<1, 0>());
|
||||
|
||||
py::class_<XlaComputation>(m, "XlaComputation")
|
||||
.def(py::init([](const py::bytes& serialized_hlo_module_proto)
|
||||
-> std::unique_ptr<XlaComputation> {
|
||||
HloModuleProto proto;
|
||||
proto.ParseFromString(serialized_hlo_module_proto);
|
||||
return absl::make_unique<XlaComputation>(proto);
|
||||
}))
|
||||
.def("GetProgramShape", &XlaComputation::GetProgramShape)
|
||||
.def("GetSerializedProto", &GetComputationSerializedProto)
|
||||
.def("GetHloText", &GetComputationHloText)
|
||||
|
@ -381,6 +381,21 @@ class ComputationsWithConstantsTest(ComputationTest):
|
||||
self._ExecuteAndCompareClose(c, expected=[0.75])
|
||||
|
||||
|
||||
class ComputationFromProtoTest(absltest.TestCase):
|
||||
"""Test computation execution from HLO proto."""
|
||||
|
||||
def testExecuteFromProto(self):
|
||||
# Build the HLO proto
|
||||
b = xla_client.ComputationBuilder("computation")
|
||||
b.Add(b.Constant(np.int8(1)), b.Constant(np.int8(2)))
|
||||
serialized_proto = b.Build().GetSerializedProto()
|
||||
|
||||
# Load and execute the proto
|
||||
c = xla_client.Computation(xla_client._xla.XlaComputation(serialized_proto))
|
||||
ans, = xla_client.execute_with_python_values(c.Compile())
|
||||
np.testing.assert_equal(ans, np.int8(3))
|
||||
|
||||
|
||||
class ParametersTest(ComputationTest):
|
||||
"""Tests focusing on Parameter ops and argument-passing."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user