[XLA:Python] Remove xla_python:: namespace. Refactoring only, no functional changes.
Rename LocalShapedBuffer to xla::PyLocalBuffer to indicate more clearly that it is a Python binding class. Move some XlaComputation helpers into an anonymous namespace inside xla.cc so they cannot collide with anything outside the Python bindings; they did not really fit in local_client.{cc,h} anyway.

PiperOrigin-RevId: 244392801
parent 96ce8df98a
commit 1dca7db621
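The message leans on two conventions: a `Py` prefix that marks a class as existing only to back a Python object, and an unnamed (anonymous) namespace that gives file-local helpers internal linkage so they cannot collide with other symbols in `xla::`. A minimal, self-contained pybind11 sketch of both follows. It is not code from this commit; `PyDemoBuffer`, `HelperText`, and the `demo_extension` module name are hypothetical.

// Illustrative only -- not part of this commit. PyDemoBuffer, HelperText,
// and demo_extension are made-up names.
#include <string>
#include <utility>

#include "pybind11/pybind11.h"

namespace xla {

// "Py" prefix marks a class that exists to back a Python object, the
// convention this commit adopts with PyLocalBuffer.
class PyDemoBuffer {
 public:
  explicit PyDemoBuffer(std::string payload) : payload_(std::move(payload)) {}
  const std::string& payload() const { return payload_; }

 private:
  std::string payload_;
};

namespace {

// Internal linkage: this helper is invisible outside this translation unit,
// so it cannot collide with another HelperText defined elsewhere in xla::.
std::string HelperText(const PyDemoBuffer& buffer) {
  return "payload: " + buffer.payload();
}

}  // namespace

// PYBIND11_MODULE may sit inside namespace xla, as xla.cc itself does; the
// generated module-init symbol has C linkage and ignores the namespace.
PYBIND11_MODULE(demo_extension, m) {
  pybind11::class_<PyDemoBuffer>(m, "PyDemoBuffer")
      .def(pybind11::init<std::string>())
      .def("payload", &PyDemoBuffer::payload)
      .def("helper_text", &HelperText);  // free function bound as a method
}

}  // namespace xla

In the commit below, the same pattern keeps GetComputationSerializedProto and friends private to xla.cc, while PyLocalBuffer remains in the public xla:: surface declared by local_client.h.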
tensorflow/compiler/xla/python/BUILD
@@ -145,6 +145,7 @@ tf_pybind_extension(
         "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
         "//tensorflow/compiler/xla/client/lib:svd",
         "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
         "//tensorflow/compiler/xla/service:name_uniquer",
         "//tensorflow/compiler/xla/service:platform_util",
tensorflow/compiler/xla/python/local_client.cc
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
-#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -41,7 +40,6 @@ limitations under the License.
 #include "tensorflow/core/profiler/lib/traceme.h"

 namespace xla {
-namespace xla_python {

 namespace py = pybind11;

@@ -104,7 +102,7 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
   return LiteralToPython(absl::make_unique<Literal>(std::move(literal)));
 }

-static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
+static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
     const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
     se::Stream* stream) {
   DeviceMemoryAllocator* allocator =
@@ -132,37 +130,38 @@ static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
         transfer_manager->TransferLiteralToDeviceAsync(stream, *it, leaf));
     ++it;
   }
-  return LocalShapedBuffer(std::move(buffer), client);
+  return PyLocalBuffer(std::move(buffer), client);
 }

 /* static */
-StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
-    const py::object& argument, PyLocalClient* client, int device_ordinal) {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPython");
+StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
+                                                  PyLocalClient* client,
+                                                  int device_ordinal) {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
   TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));

   // We are done manipulating Python objects; release the GIL.
   py::gil_scoped_release gil_release;
-  VLOG(1) << "LocalShapedBuffer::FromPython: shape: " << tree.shape.ToString()
+  VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString()
           << " device ordinal: " << device_ordinal;

   TF_ASSIGN_OR_RETURN(
       StreamPool::Ptr stream,
       client->client()->mutable_backend()->BorrowStream(device_ordinal));
   TF_ASSIGN_OR_RETURN(
-      LocalShapedBuffer buffer,
+      PyLocalBuffer buffer,
       TransferHostToDeviceAsync(tree, device_ordinal, client, stream.get()));
   stream->BlockHostUntilDone();
   return buffer;
 }

-/*static */ StatusOr<std::vector<LocalShapedBuffer>>
-LocalShapedBuffer::FromPythonValues(
+/*static */ StatusOr<std::vector<PyLocalBuffer>>
+PyLocalBuffer::FromPythonValues(
     const std::vector<std::pair<py::object, int>>& arguments,
     PyLocalClient* client) {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPythonValues");
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues");
   int num_arguments = static_cast<int>(arguments.size());
-  std::vector<LocalShapedBuffer> outputs(num_arguments);
+  std::vector<PyLocalBuffer> outputs(num_arguments);
   if (num_arguments == 0) {
     return outputs;
   }
@@ -170,7 +169,7 @@ LocalShapedBuffer::FromPythonValues(
   struct H2DTransfer {
     PythonBufferTree tree;
     StreamPool::Ptr stream;
-    StatusOr<LocalShapedBuffer> buffer;
+    StatusOr<PyLocalBuffer> buffer;
   };

   std::vector<H2DTransfer> transfers(num_arguments);
@@ -188,7 +187,7 @@ LocalShapedBuffer::FromPythonValues(
         client->client()->mutable_backend()->BorrowStream(device_ordinal));
   }

-  auto transfer_h2d = [&](int i) -> StatusOr<LocalShapedBuffer> {
+  auto transfer_h2d = [&](int i) -> StatusOr<PyLocalBuffer> {
     int device_ordinal = arguments[i].second;
     return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
                                      transfers[i].stream.get());
@@ -225,26 +224,26 @@ LocalShapedBuffer::FromPythonValues(
   return outputs;
 }

-LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
-                                     PyLocalClient* client)
+PyLocalBuffer::PyLocalBuffer(ScopedShapedBuffer shaped_buffer,
+                             PyLocalClient* client)
     : shaped_buffer_(std::move(shaped_buffer)), client_(client) {}

-const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
+const ScopedShapedBuffer* PyLocalBuffer::shaped_buffer() const {
   return &shaped_buffer_.value();
 }

-ScopedShapedBuffer LocalShapedBuffer::Release() {
+ScopedShapedBuffer PyLocalBuffer::Release() {
   ScopedShapedBuffer result = std::move(*shaped_buffer_);
   shaped_buffer_ = absl::nullopt;
   return result;
 }

-const Shape& LocalShapedBuffer::shape() const {
+const Shape& PyLocalBuffer::shape() const {
   return shaped_buffer()->on_device_shape();
 }

-StatusOr<py::object> LocalShapedBuffer::ToPython() const {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::ToPython");
+StatusOr<py::object> PyLocalBuffer::ToPython() const {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
   auto literal = absl::make_unique<Literal>();
   {
     py::gil_scoped_release gil_release;
@@ -254,13 +253,13 @@ StatusOr<py::object> LocalShapedBuffer::ToPython() const {
   return LiteralToPython(std::move(literal));
 }

-StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::DestructureTuple");
+StatusOr<std::vector<PyLocalBuffer>> PyLocalBuffer::DestructureTuple() {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple");
   const Shape tuple_shape = shape();

   if (!tuple_shape.IsTuple()) {
     return InvalidArgument(
-        "Attemped to destructure a LocalShapedBuffer that did not have a tuple "
+        "Attemped to destructure a PyLocalBuffer that did not have a tuple "
         "shape; shape: %s",
         ShapeUtil::HumanString(tuple_shape));
   }
@@ -273,7 +272,7 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
   int device_ordinal = tuple_buffer.device_ordinal();

   ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
-  std::vector<LocalShapedBuffer> results;
+  std::vector<PyLocalBuffer> results;
   for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
     // Create a shaped buffer for this destructured tuple element.
     const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
@@ -291,7 +290,7 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
     });

     VLOG(3) << "Completed tuple element: " << i;
-    results.push_back(LocalShapedBuffer(
+    results.push_back(PyLocalBuffer(
         ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
   }
   return results;
@@ -314,8 +313,8 @@ std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
   return device_ordinals;
 }

-StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
-    absl::Span<LocalShapedBuffer* const> argument_handles) {
+StatusOr<PyLocalBuffer> PyLocalExecutable::Execute(
+    absl::Span<PyLocalBuffer* const> argument_handles) {
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   if (num_replicas() != 1) {
     return InvalidArgument(
@@ -345,12 +344,11 @@ StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
   if (!result_buffer_status.ok()) {
     return result_buffer_status.status();
   }
-  return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
-                           client_);
+  return PyLocalBuffer(std::move(result_buffer_status).ValueOrDie(), client_);
 }

-StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
-    absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
+StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
+    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) {
   tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica");
   const int num_devices = client_->device_count();

@@ -448,7 +446,7 @@ StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
   }
   VLOG(1) << "Replicated execution complete.";

-  std::vector<LocalShapedBuffer> wrapped_results(num_replicas());
+  std::vector<PyLocalBuffer> wrapped_results(num_replicas());
   for (int replica = 0; replica < num_replicas(); ++replica) {
     auto& statusor = results[replica];
     if (!statusor.ok()) {
@@ -460,46 +458,11 @@ StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
           replica));
     }
     wrapped_results[replica] =
-        LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
+        PyLocalBuffer(std::move(statusor).ValueOrDie(), client_);
   }
   return wrapped_results;
 }

-StatusOr<py::bytes> GetComputationSerializedProto(
-    const XlaComputation& computation) {
-  std::string result;
-  if (!computation.proto().SerializeToString(&result)) {
-    return Unknown("Failed to serialize the HloModuleProto.");
-  }
-  return py::bytes(result);
-}
-
-StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
-                      HloModule::CreateModuleConfigFromProto(
-                          computation.proto(), GetDebugOptionsFromFlags()));
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<HloModule> hlo_module,
-      HloModule::CreateFromProto(computation.proto(), module_config));
-  HloPrintOptions options;
-  options = HloPrintOptions::ShortParsable();
-  options.set_print_large_constants(false);
-  return hlo_module->ToString(options);
-}
-
-StatusOr<std::string> GetComputationHloDotGraph(
-    const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
-                      HloModule::CreateModuleConfigFromProto(
-                          computation.proto(), GetDebugOptionsFromFlags()));
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<HloModule> hlo_module,
-      HloModule::CreateFromProto(computation.proto(), module_config));
-  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
-                     hlo_module->config().debug_options(),
-                     RenderedGraphFormat::kDot);
-}
-
 /*static*/ StatusOr<std::unique_ptr<PyLocalExecutable>>
 PyLocalExecutable::Compile(const XlaComputation& computation,
                            std::vector<Shape> argument_layouts,
@@ -559,5 +522,4 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
       std::move(local_executable), std::move(device_assignment), client);
 }

-}  // namespace xla_python
 }  // namespace xla
tensorflow/compiler/xla/python/local_client.h
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"

 namespace xla {
-namespace xla_python {

 // Registers a 'fn_capsule' as a CPU custom call target.
 // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
@@ -74,20 +73,20 @@ class PyLocalClient {
 // Represents a reference to literals that live in a device-allocated buffer via
 // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
 // literal to device via the local client.
-class LocalShapedBuffer {
+class PyLocalBuffer {
  public:
-  static StatusOr<LocalShapedBuffer> FromPython(
-      const pybind11::object& argument, PyLocalClient* client,
-      int device_ordinal);
+  static StatusOr<PyLocalBuffer> FromPython(const pybind11::object& argument,
+                                            PyLocalClient* client,
+                                            int device_ordinal);

   // Converts multiple (python object, device ordinal) pairs into
-  // LocalShapedBuffers in parallel.
-  static StatusOr<std::vector<LocalShapedBuffer>> FromPythonValues(
+  // PyLocalBuffers in parallel.
+  static StatusOr<std::vector<PyLocalBuffer>> FromPythonValues(
       const std::vector<std::pair<pybind11::object, int>>& argument,
       PyLocalClient* client);

-  LocalShapedBuffer() = default;
-  LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, PyLocalClient* client);
+  PyLocalBuffer() = default;
+  PyLocalBuffer(ScopedShapedBuffer shaped_buffer, PyLocalClient* client);
   StatusOr<pybind11::object> ToPython() const;
   const Shape& shape() const;
   const ScopedShapedBuffer* shaped_buffer() const;
@@ -101,9 +100,8 @@ class LocalShapedBuffer {
     client_ = nullptr;
   }

-  // Destructures a tuple-valued LocalShapedBuffer into its constituent
-  // elements in LocalShapedBufferTuple form.
-  StatusOr<std::vector<LocalShapedBuffer>> DestructureTuple();
+  // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
+  StatusOr<std::vector<PyLocalBuffer>> DestructureTuple();

  private:
   absl::optional<ScopedShapedBuffer> shaped_buffer_;
@@ -133,14 +131,14 @@ class PyLocalExecutable {
     return device_assignment_;
   }

-  StatusOr<LocalShapedBuffer> Execute(
-      absl::Span<LocalShapedBuffer* const> argument_handles);
+  StatusOr<PyLocalBuffer> Execute(
+      absl::Span<PyLocalBuffer* const> argument_handles);

   // Execute on many replicas. Takes a sequence of argument lists (one argument
   // list per replica) and returns a tuple of results (one result per replica).
   // The number of argument lists must be equal to the replica count.
-  StatusOr<std::vector<LocalShapedBuffer>> ExecutePerReplica(
-      absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles);
+  StatusOr<std::vector<PyLocalBuffer>> ExecutePerReplica(
+      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles);

   void Delete() { executable_ = nullptr; }

@@ -150,18 +148,6 @@ class PyLocalExecutable {
   PyLocalClient* const client_;
 };

-// Converts a computation to a serialized HloModuleProto
-StatusOr<pybind11::bytes> GetComputationSerializedProto(
-    const XlaComputation& computation);
-
-// Converts a computation to textual HLO form.
-StatusOr<std::string> GetComputationHloText(const XlaComputation& computation);
-
-// Converts a computation to HLO dot graph form.
-StatusOr<std::string> GetComputationHloDotGraph(
-    const XlaComputation& computation);
-
-}  // namespace xla_python
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
tensorflow/compiler/xla/python/xla.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <string>
+
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/optional.h"
@@ -29,16 +31,21 @@ limitations under the License.
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/python/xrt.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"

 namespace xla {
-namespace xla_python {

 namespace py = pybind11;

+namespace {
+
 struct Uniquer {
   absl::Mutex mu;
   NameUniquer name_uniquer GUARDED_BY(mu);
@@ -55,6 +62,46 @@ static string UniquifyName(const string& name) {
   return uniquer->name_uniquer.GetUniqueName(name);
 }

+// Converts a computation to a serialized HloModuleProto.
+StatusOr<py::bytes> GetComputationSerializedProto(
+    const XlaComputation& computation) {
+  std::string result;
+  if (!computation.proto().SerializeToString(&result)) {
+    return Unknown("Failed to serialize the HloModuleProto.");
+  }
+  return py::bytes(result);
+}
+
+// Converts a computation to textual HLO form.
+StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  HloPrintOptions options;
+  options = HloPrintOptions::ShortParsable();
+  options.set_print_large_constants(false);
+  return hlo_module->ToString(options);
+}
+
+// Converts a computation to HLO dot graph form.
+StatusOr<std::string> GetComputationHloDotGraph(
+    const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
+                     hlo_module->config().debug_options(),
+                     RenderedGraphFormat::kDot);
+}
+
+}  // namespace
+
 PYBIND11_MODULE(xla_extension, m) {
   // Types
   py::enum_<PrimitiveType>(m, "PrimitiveType")
@@ -167,13 +214,13 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
       .def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);

-  py::class_<LocalShapedBuffer>(m, "LocalShapedBuffer")
-      .def_static("FromPython", &LocalShapedBuffer::FromPython)
-      .def_static("FromPythonValues", &LocalShapedBuffer::FromPythonValues)
-      .def("Delete", &LocalShapedBuffer::Delete)
-      .def("DestructureTuple", &LocalShapedBuffer::DestructureTuple)
-      .def("ToPython", &LocalShapedBuffer::ToPython)
-      .def("shape", &LocalShapedBuffer::shape);
+  py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
+      .def_static("FromPython", &PyLocalBuffer::FromPython)
+      .def_static("FromPythonValues", &PyLocalBuffer::FromPythonValues)
+      .def("Delete", &PyLocalBuffer::Delete)
+      .def("DestructureTuple", &PyLocalBuffer::DestructureTuple)
+      .def("ToPython", &PyLocalBuffer::ToPython)
+      .def("shape", &PyLocalBuffer::shape);

   py::class_<PyLocalExecutable>(m, "LocalExecutable")
       .def_static("Compile", &PyLocalExecutable::Compile,
@@ -433,5 +480,4 @@ PYBIND11_MODULE(xla_extension, m) {
   tensorflow::AddXrtSubmodule(&m);
 }

-}  // namespace xla_python
 }  // namespace xla
tensorflow/compiler/xla/python/xla_client.py
@@ -113,11 +113,10 @@ class LocalBackend(Backend):
     return self.client.DeviceCount()

   def buffer_from_pyval(self, pyval, device=0):
-    return _xla.LocalShapedBuffer.FromPython(pyval, self.client, device)
+    return _xla.PyLocalBuffer.FromPython(pyval, self.client, device)

   def buffers_from_pyvals(self, pyvals_and_devices):
-    return _xla.LocalShapedBuffer.FromPythonValues(pyvals_and_devices,
-                                                   self.client)
+    return _xla.PyLocalBuffer.FromPythonValues(pyvals_and_devices, self.client)

   def delete_buffer(self, c_buffer):
     c_buffer.Delete()