[xla] Expose XlaComputation constructor in the Python API

PiperOrigin-RevId: 304430744
Change-Id: Iaef4f9e749b226c821aafcae440cdfc590828473
This commit is contained in:
A. Unique TensorFlower 2020-04-02 10:40:52 -07:00 committed by TensorFlower Gardener
parent eb0998340c
commit bb354b5f64
2 changed files with 21 additions and 0 deletions

View File

@ -1206,6 +1206,12 @@ PYBIND11_MODULE(xla_extension, m) {
py::return_value_policy::reference, py::keep_alive<1, 0>());
py::class_<XlaComputation>(m, "XlaComputation")
.def(py::init([](const py::bytes& serialized_hlo_module_proto)
-> std::unique_ptr<XlaComputation> {
HloModuleProto proto;
proto.ParseFromString(serialized_hlo_module_proto);
return absl::make_unique<XlaComputation>(proto);
}))
.def("GetProgramShape", &XlaComputation::GetProgramShape)
.def("GetSerializedProto", &GetComputationSerializedProto)
.def("GetHloText", &GetComputationHloText)

View File

@ -381,6 +381,21 @@ class ComputationsWithConstantsTest(ComputationTest):
self._ExecuteAndCompareClose(c, expected=[0.75])
class ComputationFromProtoTest(absltest.TestCase):
"""Test computation execution from HLO proto."""
def testExecuteFromProto(self):
# Build the HLO proto
b = xla_client.ComputationBuilder("computation")
b.Add(b.Constant(np.int8(1)), b.Constant(np.int8(2)))
serialized_proto = b.Build().GetSerializedProto()
# Load and execute the proto
c = xla_client.Computation(xla_client._xla.XlaComputation(serialized_proto))
ans, = xla_client.execute_with_python_values(c.Compile())
np.testing.assert_equal(ans, np.int8(3))
class ParametersTest(ComputationTest):
"""Tests focusing on Parameter ops and argument-passing."""