Python library and C++ bindings for creating and compiling local XLA computations.

PiperOrigin-RevId: 179211353
Roy Frostig 2017-12-15 10:38:16 -08:00 committed by TensorFlower Gardener
parent 22fe6558a9
commit 75a91cf3be
14 changed files with 2951 additions and 0 deletions

tensorflow/BUILD

@@ -411,6 +411,7 @@ filegroup(
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
"//tensorflow/compiler/xla/legacy_flags:all_files",
"//tensorflow/compiler/xla/python:all_files",
"//tensorflow/compiler/xla/service:all_files",
"//tensorflow/compiler/xla/service/cpu:all_files",
"//tensorflow/compiler/xla/service/gpu:all_files",

tensorflow/compiler/xla/BUILD

@@ -20,6 +20,10 @@ package_group(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library_py",
)
# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -36,6 +40,12 @@ xla_proto_library(
visibility = ["//visibility:public"],
)
tf_proto_library_py(
name = "xla_data_proto", # bzl adds a _py suffix
srcs = ["xla_data.proto"],
visibility = ["//visibility:public"],
)
xla_proto_library(
name = "xla_proto",
srcs = ["xla.proto"],

tensorflow/compiler/xla/python/BUILD

@@ -0,0 +1,82 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
py_library(
name = "xla_client",
srcs = ["xla_client.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pywrap_xla",
"//tensorflow/compiler/xla:xla_data_proto_py",
],
)
py_test(
name = "xla_client_test",
srcs = ["xla_client_test.py"],
main = "xla_client_test.py",
srcs_version = "PY2AND3",
deps = [
":xla_client",
"//tensorflow/python:platform_test",
],
)
cc_library(
name = "numpy_bridge",
srcs = ["numpy_bridge.cc"],
hdrs = ["numpy_bridge.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
],
)
cc_library(
name = "local_computation_builder",
srcs = ["local_computation_builder.cc"],
hdrs = ["local_computation_builder.h"],
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:lib",
],
)
tf_py_wrap_cc(
name = "pywrap_xla",
srcs = ["xla.i"],
swig_includes = [
"local_computation_builder.i",
],
deps = [
":local_computation_builder",
":numpy_bridge",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
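
These targets chain together as follows: tf_py_wrap_cc builds the SWIG extension module pywrap_xla from xla.i, and the xla_client py_library layers the pure-Python API over it, also pulling in the generated xla_data proto module. A minimal sketch of the imports this wiring enables (mirroring what xla_client.py itself does later in this commit):

    from tensorflow.compiler.xla import xla_data_pb2                # xla_data_proto_py
    from tensorflow.compiler.xla.python import pywrap_xla as c_api  # tf_py_wrap_cc target
    from tensorflow.compiler.xla.python import xla_client           # py_library target

    cdh = xla_data_pb2.ComputationDataHandle()               # proto wrapper for op handles
    raw_builder = c_api.LocalComputationBuilder(b'example')  # raw SWIG object; takes bytes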

tensorflow/compiler/xla/python/local_computation_builder.cc

@@ -0,0 +1,265 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace swig {
CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
std::unique_ptr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
// Transfer arguments in
std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
scoped_buffers.reserve(arguments.size());
for (const Literal& argument : arguments) {
scoped_buffers.push_back(
client
->LiteralToShapedBuffer(argument,
/*device_ordinal=*/0,
client->backend().memory_allocator())
.ConsumeValueOrDie());
}
// Execute
std::vector<const ShapedBuffer*> argument_buffers;
argument_buffers.reserve(scoped_buffers.size());
for (auto& buffer : scoped_buffers) {
argument_buffers.push_back(buffer.get());
}
ExecutableRunOptions options;
options.set_allocator(client->backend().memory_allocator());
options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
options.set_intra_op_thread_pool(
client->backend().eigen_intra_op_thread_pool_device());
std::unique_ptr<ScopedShapedBuffer> result_buffer =
executable_->Run(argument_buffers, options).ConsumeValueOrDie();
// Transfer result out
return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie();
}
LocalComputation::LocalComputation(std::unique_ptr<Computation> computation)
: computation_(std::move(computation)) {}
CompiledLocalComputation* LocalComputation::Compile(
const std::vector<Shape>& argument_shapes) {
std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) {
argument_shape_pointers.push_back(&argument_shape);
}
LocalClient* client = ClientLibrary::LocalClientOrDie();
ExecutableBuildOptions options;
return new CompiledLocalComputation(
client->Compile(*computation_, argument_shape_pointers, options)
.ValueOrDie());
}
const Computation& LocalComputation::computation() const {
return *computation_;
}
LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
: builder_(ClientLibrary::LocalClientOrDie(), computation_name) {}
LocalComputation* LocalComputationBuilder::Build() {
return new LocalComputation(std::unique_ptr<Computation>(
new Computation(builder_.Build().ConsumeValueOrDie())));
}
ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number,
const Shape& shape,
const string& name) {
return builder_.Parameter(parameter_number, shape, name);
}
std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
const ComputationDataHandle& operand) {
return builder_.GetShape(operand).ConsumeValueOrDie();
}
ComputationDataHandle LocalComputationBuilder::ConstantLiteral(
const Literal& literal) {
return builder_.ConstantLiteral(literal);
}
ComputationDataHandle LocalComputationBuilder::Broadcast(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
return builder_.Broadcast(operand, broadcast_sizes);
}
ComputationDataHandle LocalComputationBuilder::Reshape(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
return builder_.Reshape(operand, dimensions, new_sizes);
}
ComputationDataHandle LocalComputationBuilder::Slice(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
return builder_.Slice(operand, start_indices, limit_indices, strides);
}
ComputationDataHandle LocalComputationBuilder::DynamicSlice(
const ComputationDataHandle& operand,
const ComputationDataHandle& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
return builder_.DynamicSlice(operand, start_indices, slice_sizes);
}
ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice(
const ComputationDataHandle& operand, const ComputationDataHandle& update,
const ComputationDataHandle& start_indices) {
return builder_.DynamicUpdateSlice(operand, update, start_indices);
}
ComputationDataHandle LocalComputationBuilder::ConcatInDim(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
int64 dimension) {
return builder_.ConcatInDim(operands, dimension);
}
ComputationDataHandle LocalComputationBuilder::Select(
const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false) {
return builder_.Select(pred, on_true, on_false);
}
ComputationDataHandle LocalComputationBuilder::Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
return builder_.Tuple(elements);
}
ComputationDataHandle LocalComputationBuilder::GetTupleElement(
const ComputationDataHandle& tuple_data, int64 index) {
return builder_.GetTupleElement(tuple_data, index);
}
ComputationDataHandle LocalComputationBuilder::Dot(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
return builder_.Dot(lhs, rhs);
}
ComputationDataHandle LocalComputationBuilder::ConvertElementType(
const ComputationDataHandle& operand, PrimitiveType new_element_type) {
return builder_.ConvertElementType(operand, new_element_type);
}
ComputationDataHandle LocalComputationBuilder::Call(
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
return builder_.Call(local_computation.computation(), operands);
}
ComputationDataHandle LocalComputationBuilder::Transpose(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
return builder_.Transpose(operand, permutation);
}
ComputationDataHandle LocalComputationBuilder::Map(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
return builder_.Map(operands, local_computation.computation(), dimensions,
static_operands);
}
ComputationDataHandle LocalComputationBuilder::Reduce(
const ComputationDataHandle& operand,
const ComputationDataHandle& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
return builder_.Reduce(operand, init_value, local_computation.computation(),
dimensions_to_reduce);
}
ComputationDataHandle LocalComputationBuilder::While(
const LocalComputation& condition, const LocalComputation& body,
const ComputationDataHandle& init) {
return builder_.While(condition.computation(), body.computation(), init);
}
#define _FORWARD(method_name, return_sig, args_sig, args) \
return_sig LocalComputationBuilder::method_name args_sig { \
return builder_.method_name args; \
}
#define _FORWARD_UNOP(method_name) \
_FORWARD(method_name, ComputationDataHandle, \
(const ComputationDataHandle& operand), (operand))
#define _FORWARD_BINOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
(lhs, rhs, broadcast_dimensions))
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(SqrtF32)
_FORWARD_UNOP(SquareF32)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(ReciprocalF32)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
} // namespace swig
} // namespace xla
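
For orientation, here is a minimal sketch of how this compile-and-execute path is driven from Python via xla_client.py (added later in this commit). Methods such as Add are generated from the _BINARY_OPS list there, but the generator itself falls outside this excerpt, so the exact Python-level signature is assumed:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    builder = xla_client.ComputationBuilder('add_vectors')
    x = builder.ParameterFromNumpy(np.zeros(4, dtype=np.float32))  # parameter 0
    y = builder.ParameterFromNumpy(np.zeros(4, dtype=np.float32))  # parameter 1
    builder.Add(x, y)  # generated binary op; signature assumed

    computation = builder.Build()  # uncompiled LocalComputation
    args = (np.arange(4, dtype=np.float32), np.ones(4, dtype=np.float32))
    compiled = computation.CompileWithExampleArguments(args)
    result = compiled.Execute(args)  # literals transferred in and back out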

tensorflow/compiler/xla/python/local_computation_builder.h

@@ -0,0 +1,210 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace swig {
// Wraps a LocalExecutable produced by compiling a
// LocalComputation. The Execute method forwards to that of the
underlying LocalExecutable, and additionally handles transferring
// arguments and return values in and back out of the client library's
// local client. This class is intended to be made available to Python
// via SWIG.
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments);
private:
std::unique_ptr<LocalExecutable> executable_;
};
// Wraps a Computation produced by a LocalComputationBuilder. The
// Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be
// made available to Python via SWIG.
class LocalComputation {
public:
LocalComputation(std::unique_ptr<Computation> computation);
CompiledLocalComputation* Compile(const std::vector<Shape>& argument_shapes);
const Computation& computation() const;
private:
std::unique_ptr<Computation> computation_;
};
// Wraps the ComputationBuilder API in order to:
// - Support consumption by SWIG in order to be made available to
// Python.
// - Set up the underlying builder to use the client library's
// LocalClient.
// - Wrap Computations in LocalComputations for Python access.
// - Correspondingly unwrap incoming LocalComputations.
class LocalComputationBuilder {
public:
LocalComputationBuilder(const string& computation_name);
LocalComputation* Build();
ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
const string& name);
std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);
ComputationDataHandle ConstantLiteral(const Literal& literal);
ComputationDataHandle Broadcast(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
ComputationDataHandle Reshape(const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes);
ComputationDataHandle Slice(const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides);
ComputationDataHandle DynamicSlice(
const ComputationDataHandle& operand,
const ComputationDataHandle& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes);
ComputationDataHandle DynamicUpdateSlice(
const ComputationDataHandle& operand, const ComputationDataHandle& update,
const ComputationDataHandle& start_indices);
ComputationDataHandle ConcatInDim(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
int64 dimension);
ComputationDataHandle Select(const ComputationDataHandle& pred,
const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false);
ComputationDataHandle Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
int64 index);
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
PrimitiveType new_element_type);
ComputationDataHandle Call(
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
ComputationDataHandle Transpose(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> permutation);
ComputationDataHandle Map(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands);
ComputationDataHandle Reduce(
const ComputationDataHandle& operand,
const ComputationDataHandle& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
ComputationDataHandle While(const LocalComputation& condition,
const LocalComputation& body,
const ComputationDataHandle& init);
#define _FORWARD(method_name, return_sig, args_sig) \
return_sig method_name args_sig;
#define _FORWARD_UNOP(method_name) \
_FORWARD(method_name, ComputationDataHandle, \
(const ComputationDataHandle& operand))
#define _FORWARD_BINOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(SqrtF32)
_FORWARD_UNOP(SquareF32)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(ReciprocalF32)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
private:
ComputationBuilder builder_;
};
static void DeleteLocalComputation(LocalComputation* computation) {
delete computation;
}
static void DeleteCompiledLocalComputation(
CompiledLocalComputation* computation) {
delete computation;
}
} // namespace swig
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
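
At the SWIG boundary, these methods traffic in raw values rather than wrapped Python objects: per the typemaps in local_computation_builder.i (the next file), a ComputationDataHandle crosses as a plain Python integer and a Shape as a (numpy dtype, dimensions) pair. A rough sketch against the raw wrapper, with the Add signature taken from the _FORWARD_BINOP expansion:

    import numpy as np
    from tensorflow.compiler.xla.python import pywrap_xla as c_api

    c_builder = c_api.LocalComputationBuilder(b'raw_add')
    shape = (np.dtype('float32'), (2, 3))      # shape-info pair, per the typemaps
    h_x = c_builder.Parameter(0, shape, b'x')  # handle comes back as a Python int
    h_y = c_builder.Parameter(1, shape, b'y')
    h_sum = c_builder.Add(h_x, h_y, ())        # empty broadcast_dimensions sequence
    local_computation = c_builder.Build()      # SWIG-wrapped LocalComputation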

tensorflow/compiler/xla/python/local_computation_builder.i

@@ -0,0 +1,348 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// SWIG typemaps and declarations for building, compiling, and
// executing XLA computations, wrapping most of what is declared in
// local_computation_builder.h.
//
// The typemaps below implement/assert the following correspondences
// (with elaborations below):
//
// C++ Python
// -------------------------------------+---------------------------------------
// ComputationDataHandle <-> long
// ArraySlice<int64> <- sequence of long
// ArraySlice<ComputationDataHandle> <- sequence of long
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape <-> pair holding (dtype, dimensions)
// std::vector<Shape> <- sequence of shape information pairs
// PrimitiveType <- int
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally. Also,
// "long" and "int" denote the Python types so named, not C.
//
// The Python objects corresponding to C++ Literals have the type:
//
// T = ndarray | (T, ...)
//
// where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape, an XLA primitive element type corresponding to the
// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
// to a tuple-shaped Literal whose tuple components are translated
// recursively. For example, if x is a numpy ndarray in Python, with
// shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimension 2 and 3, and XLA primitive type
// F32. Meanwhile,
//
// (x, (x, x), (x,)),
//
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// The Python objects corresponding to C++ Shapes have the type:
//
// T = (dtype, S)
// S = DIMENSIONS | TUPLE_SHAPES
// DIMENSIONS = (int, ...)
// TUPLE_SHAPES = (T, ...)
//
// In the pair described by the T rule, the terminal dtype determines
// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
// dtype('O'), numpy's object dtype, the structure represents a tuple
// shape and the expansion of the non-terminal S is
// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
// and S expands into DIMENSIONS giving dimension sizes. For example:
//
// (dtype('float32'), (3, 5, 7))
//
// describes a 3x5x7 array of F32s, and
//
// (dtype('O'), ((dtype('float32'), (2, 3)),
// (dtype('float64'), (4, 5))))
//
// describes a tuple shape with two subshapes: the first a 2x3 F32,
// and the other a 4x5 F64.
//
// The Python int corresponding to a PrimitiveType enum must be valid
// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
//
// The SWIG object wrappers generated by this file are not intended
// for end use, but rather for internal use in the Python XLA client,
// xla_client.py.
//
// One central reason for the Python-side indirection is that the
// Python-side objects produced by the typemaps in this file are
// further packaged up by xla_client before being passed on. For
// instance, xla_client wraps the long produced for a C++
// ComputationDataHandle in a Python ComputationDataHandle proto,
// rather than exposing a raw long outside of the client. Similarly,
// the Python pair produced for a C++ Shape is further wrapped in a
// Python class (xla_client.Shape) so as not to expose the raw pair
// externally.
//
// Other SWIG object wrappers (e.g. of LocalComputation) are further
// wrapped by xla_client in order to set up a custom destructor that
// triggers memory deallocation on the C++ side.
%include "tensorflow/python/platform/base.i"
%{
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
using namespace xla;
using namespace xla::swig;
%}
// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}
// ComputationDataHandle
%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
const int64 handle = numpy::PyIntOrPyLongToLong($input);
if (handle == -1 && PyErr_Occurred()) {
return NULL;
}
temp.set_handle(handle);
$1 = &temp;
}
%typemap(out) ComputationDataHandle {
$result = numpy::LongToPyIntOrPyLong($1.handle());
}
// ArraySlice<int64>
%typemap(in) tensorflow::gtl::ArraySlice<int64>
(std::vector<int64> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
temps.resize(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
PyObject* py_int = numpy::PyNumberToPyInt(o);
if (!py_int) {
PyErr_SetString(
PyExc_TypeError,
"Argument sequence element cannot be converted to int");
Py_DECREF(o);
return NULL;
}
temps[i] = numpy::PyIntOrPyLongToLong(py_int);
if (temps[i] == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
Py_DECREF(o);
return NULL;
}
Py_DECREF(py_int);
Py_DECREF(o);
}
$1 = temps;
}
// ComputationDataHandle
%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
(std::vector<ComputationDataHandle> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
temps.resize(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
PyObject* py_int = numpy::PyNumberToPyInt(o);
if (!py_int) {
PyErr_SetString(
PyExc_TypeError,
"Argument sequence element cannot be converted to int");
Py_DECREF(o);
return NULL;
}
const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
if (handle == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
Py_DECREF(o);
return NULL;
}
temps[i].set_handle(handle);
Py_DECREF(py_int);
Py_DECREF(o);
}
$1 = temps;
}
// Literal
%typemap(in) const Literal& (std::unique_ptr<Literal> temp) {
temp = numpy::XlaLiteralFromPyObject($input);
$1 = &*temp;
}
%typemap(out) std::unique_ptr<Literal> {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
temps.push_back(*numpy::XlaLiteralFromPyObject(o));
Py_DECREF(o);
}
$1 = &temps;
}
// Shape
%typemap(in) const Shape& (Shape temp) {
if (!numpy::CheckPyShapeInfo($input)) {
return NULL;
}
temp = numpy::XlaShapeFromPyShapeInfo($input);
$1 = &temp;
}
%typemap(out) std::unique_ptr<Shape> {
$result = numpy::PyShapeInfoFromXlaShape(*$1);
}
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (!numpy::CheckPyShapeInfo(o)) {
Py_DECREF(o);
return NULL;
}
temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
Py_DECREF(o);
}
$1 = &temps;
}
// PrimitiveType
%typemap(in) PrimitiveType {
PyObject* py_int = numpy::PyNumberToPyInt($input);
if (!py_int) {
PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
return NULL;
}
const long value = numpy::PyIntOrPyLongToLong(py_int);
if (value == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
return NULL;
}
if (!PrimitiveType_IsValid(value)) {
PyErr_SetString(
PyExc_TypeError, "Argument not valid for PrimitiveType enum");
Py_DECREF(py_int);
return NULL;
}
$1 = static_cast<PrimitiveType>(value);
}
%ignoreall
%unignore xla;
%unignore xla::swig;
%unignore xla::swig::CompiledLocalComputation;
%unignore xla::swig::CompiledLocalComputation::Execute;
%unignore xla::swig::LocalComputation;
%unignore xla::swig::LocalComputation::Compile;
%unignore xla::swig::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::Build;
%unignore xla::swig::LocalComputationBuilder::Parameter;
%unignore xla::swig::LocalComputationBuilder::GetShape;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
%unignore xla::swig::LocalComputationBuilder::ConstantR0;
%unignore xla::swig::LocalComputationBuilder::Broadcast;
%unignore xla::swig::LocalComputationBuilder::Reshape;
%unignore xla::swig::LocalComputationBuilder::Slice;
%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
%unignore xla::swig::LocalComputationBuilder::Select;
%unignore xla::swig::LocalComputationBuilder::Tuple;
%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::While;
%unignore xla::swig::LocalComputationBuilder::Eq;
%unignore xla::swig::LocalComputationBuilder::Ne;
%unignore xla::swig::LocalComputationBuilder::Ge;
%unignore xla::swig::LocalComputationBuilder::Gt;
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul;
%unignore xla::swig::LocalComputationBuilder::Div;
%unignore xla::swig::LocalComputationBuilder::Rem;
%unignore xla::swig::LocalComputationBuilder::Max;
%unignore xla::swig::LocalComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
%unignore xla::swig::LocalComputationBuilder::SqrtF32;
%unignore xla::swig::LocalComputationBuilder::SquareF32;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DeleteCompiledLocalComputation;
%include "tensorflow/compiler/xla/python/local_computation_builder.h"
%unignoreall
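
The correspondences documented at the top of this file can be written out concretely. A small illustration in Python, using only the forms described in the comment above:

    import numpy as np

    # A 3x5x7 array of F32s, as a (dtype, dimensions) shape record:
    array_shape = (np.dtype('float32'), (3, 5, 7))

    # A tuple shape; dtype('O') selects the TUPLE_SHAPES expansion:
    tuple_shape = (np.dtype('O'), ((np.dtype('float32'), (2, 3)),
                                   (np.dtype('float64'), (4, 5))))

    # A literal-valued object of type T = ndarray | (T, ...):
    x = np.arange(6, dtype=np.float32).reshape((2, 3))
    nested_literal = (x, (x, x), (x,))  # translates to a tuple-shaped Literal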

tensorflow/compiler/xla/python/numpy_bridge.cc

@@ -0,0 +1,389 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace swig {
namespace numpy {
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
switch (primitive_type) {
case PRED:
return NPY_BOOL;
case S8:
return NPY_INT8;
case S16:
return NPY_INT16;
case S32:
return NPY_INT32;
case S64:
return NPY_INT64;
case U8:
return NPY_UINT8;
case U16:
return NPY_UINT16;
case U32:
return NPY_UINT32;
case U64:
return NPY_UINT64;
case F16:
return NPY_FLOAT16;
case F32:
return NPY_FLOAT32;
case F64:
return NPY_FLOAT64;
case TUPLE:
return NPY_OBJECT;
default:
LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
}
}
PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
switch (np_type) {
case NPY_BOOL:
return PRED;
case NPY_INT8:
return S8;
case NPY_INT16:
return S16;
case NPY_INT32:
return S32;
case NPY_INT64:
return S64;
case NPY_UINT8:
return U8;
case NPY_UINT16:
return U16;
case NPY_UINT32:
return U32;
case NPY_UINT64:
return U64;
case NPY_FLOAT16:
return F16;
case NPY_FLOAT32:
return F32;
case NPY_FLOAT64:
return F64;
case NPY_OBJECT:
return TUPLE;
default:
LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
}
}
bool NumpyTypeIsValid(int np_type) {
switch (np_type) {
case NPY_BOOL:
case NPY_INT8:
case NPY_INT16:
case NPY_INT32:
case NPY_INT64:
case NPY_UINT8:
case NPY_UINT16:
case NPY_UINT32:
case NPY_UINT64:
case NPY_FLOAT16:
case NPY_FLOAT32:
case NPY_FLOAT64:
case NPY_OBJECT:
return true;
default:
return false;
}
}
PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
PyObject* dimensions;
if (ShapeUtil::IsTuple(shape)) {
int num_elements = ShapeUtil::TupleElementCount(shape);
dimensions = PyTuple_New(num_elements);
for (int i = 0; i < num_elements; ++i) {
PyTuple_SET_ITEM(
dimensions, i,
PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
}
} else {
int rank = ShapeUtil::Rank(shape);
dimensions = PyTuple_New(rank);
for (int i = 0; i < rank; ++i) {
PyTuple_SET_ITEM(dimensions, i,
LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
}
}
return PyTuple_Pack(2, np_dtype, dimensions);
}
// Precondition: o->ob_type == &PyArrayDescr_Type
static int NumpyTypenum(PyObject* o) {
return reinterpret_cast<PyArray_Descr*>(o)->type_num;
}
bool CheckPyShapeInfo(PyObject* o) {
// The object is a tuple (a pair)
if (!PyTuple_Check(o)) {
PyErr_SetString(PyExc_TypeError, "Shape record must be a tuple");
return false;
}
if (PyTuple_Size(o) != 2) {
PyErr_SetString(PyExc_ValueError, "Shape record tuple must be of length 2");
return false;
}
// It has a first element, which is a numpy dtype object
PyObject* first = PyTuple_GetItem(o, 0);
if (!first) {
return false;
}
if (first->ob_type != &PyArrayDescr_Type) {
PyErr_SetString(
PyExc_TypeError,
"Shape record does not have a numpy dtype as its first element");
return false;
}
const int np_type = NumpyTypenum(first);
if (!NumpyTypeIsValid(np_type)) {
PyErr_SetString(PyExc_ValueError,
"Shape record has an invalid integer dtype");
return false;
}
// It has a second element, which is a tuple, either of shape
// records or of Python ints
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
return false;
}
if (!PyTuple_Check(second)) {
PyErr_SetString(PyExc_TypeError,
"Shape record does not have a tuple as its second element");
return false;
}
const int length = PyTuple_Size(second);
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
for (int i = 0; i < length; i++) {
PyObject* dimension = PyTuple_GetItem(second, i);
if (element_type == TUPLE) {
if (!CheckPyShapeInfo(dimension)) {
return false;
}
} else if (!CheckPyIntOrLong(dimension)) {
PyErr_SetString(PyExc_TypeError,
"Non-tuple shape record has a non-integer dimension");
return false;
}
}
return true;
}
// Precondition: CheckPyShapeInfo(o)
Shape XlaShapeFromPyShapeInfo(PyObject* o) {
const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
PyObject* py_dimensions = PyTuple_GetItem(o, 1);
const int length = PyTuple_Size(py_dimensions);
if (element_type == TUPLE) {
std::vector<Shape> subshapes;
subshapes.reserve(length);
for (int i = 0; i < length; i++) {
subshapes.push_back(
XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
}
return ShapeUtil::MakeTupleShape(subshapes);
} else {
std::vector<int64> dimensions(length);
for (int i = 0; i < length; i++) {
dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
if (dimensions[i] == -1) {
CHECK(!PyErr_Occurred());
}
}
return ShapeUtil::MakeShape(element_type, dimensions);
}
}
PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
const std::vector<Literal>& tuple_literals = literal.tuple_literals();
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements);
for (int i = 0; i < num_elements; i++) {
PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i]));
}
return tuple;
} else {
int rank = ShapeUtil::Rank(literal.shape());
std::vector<long> dimensions(rank); // NOLINT - PyArray requires a long*
for (int i = 0; i < rank; i++) {
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
}
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
PyObject* array =
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
CopyLiteralToNumpyArray(np_type, literal,
reinterpret_cast<PyArrayObject*>(array));
return array;
}
}
std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
std::vector<std::unique_ptr<Literal>> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
elements.push_back(XlaLiteralFromPyObject(element));
}
return Literal::MakeTupleOwned(std::move(elements));
} else if (PyArray_Check(o)) {
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
int rank = PyArray_NDIM(py_array);
std::vector<int64> dimensions(rank);
for (int i = 0; i < rank; i++) {
dimensions[i] = PyArray_DIM(py_array, i);
}
int np_type = PyArray_TYPE(py_array);
auto literal = Literal::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
CopyNumpyArrayToLiteral(np_type, py_array, literal.get());
return literal;
} else {
LOG(FATAL) << "Encountered object that is neither a tuple nor a Numpy array "
"in conversion to XLA literal";
}
}
void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal) {
switch (np_type) {
case NPY_BOOL:
CopyNumpyArrayToLiteral<bool>(py_array, literal);
break;
case NPY_INT32:
CopyNumpyArrayToLiteral<int32>(py_array, literal);
break;
case NPY_INT64:
CopyNumpyArrayToLiteral<int64>(py_array, literal);
break;
case NPY_UINT8:
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
break;
case NPY_UINT32:
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
break;
case NPY_UINT64:
CopyNumpyArrayToLiteral<uint64>(py_array, literal);
break;
case NPY_FLOAT16:
CopyNumpyArrayToLiteral<half>(py_array, literal);
break;
case NPY_FLOAT32:
CopyNumpyArrayToLiteral<float>(py_array, literal);
break;
case NPY_FLOAT64:
CopyNumpyArrayToLiteral<double>(py_array, literal);
break;
default:
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
}
}
void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
PyArrayObject* py_array) {
switch (np_type) {
case NPY_BOOL:
CopyLiteralToNumpyArray<bool>(literal, py_array);
break;
case NPY_INT32:
CopyLiteralToNumpyArray<int32>(literal, py_array);
break;
case NPY_INT64:
CopyLiteralToNumpyArray<int64>(literal, py_array);
break;
case NPY_UINT8:
CopyLiteralToNumpyArray<uint8>(literal, py_array);
break;
case NPY_UINT32:
CopyLiteralToNumpyArray<uint32>(literal, py_array);
break;
case NPY_UINT64:
CopyLiteralToNumpyArray<uint64>(literal, py_array);
break;
case NPY_FLOAT16:
CopyLiteralToNumpyArray<half>(literal, py_array);
break;
case NPY_FLOAT32:
CopyLiteralToNumpyArray<float>(literal, py_array);
break;
case NPY_FLOAT64:
CopyLiteralToNumpyArray<double>(literal, py_array);
break;
default:
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
}
}
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
#if PY_MAJOR_VERSION < 3
return PyInt_FromLong(x);
#else
return PyLong_FromLong(x);
#endif
}
long PyIntOrPyLongToLong(PyObject* o) { // NOLINT
#if PY_MAJOR_VERSION < 3
return PyInt_AsLong(o);
#else
return PyLong_AsLong(o);
#endif
}
bool CheckPyIntOrLong(PyObject* o) {
#if PY_MAJOR_VERSION < 3
return PyInt_Check(o);
#else
if (!PyLong_Check(o)) {
return false;
}
int overflow = 0;
PyLong_AsLongAndOverflow(o, &overflow);
return (overflow == 0);
#endif
}
PyObject* PyNumberToPyInt(PyObject* o) {
#if PY_MAJOR_VERSION < 3
return PyNumber_Int(o);
#else
return PyNumber_Long(o);
#endif
}
} // namespace numpy
} // namespace swig
} // namespace xla
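
The shape-record rules that CheckPyShapeInfo enforces can be restated compactly in Python. This is an illustrative sketch of the same checks, not code from this commit (the C++ version additionally verifies that the dtype maps to a supported XLA primitive type):

    import numpy as np

    def check_py_shape_info(o):
      """Sketch: is o a (np.dtype, dimensions) pair as described above?"""
      if not (isinstance(o, tuple) and len(o) == 2):
        return False
      dtype, dims = o
      if not isinstance(dtype, np.dtype) or not isinstance(dims, tuple):
        return False
      if dtype == np.dtype('O'):  # tuple shape: dims holds nested shape records
        return all(check_py_shape_info(d) for d in dims)
      return all(isinstance(d, int) for d in dims)  # array shape: int dimensions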

tensorflow/compiler/xla/python/numpy_bridge.h

@@ -0,0 +1,123 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// These functions transform Python/Numpy data structures to XLA data
// structures and vice versa, performing copies where
// appropriate. Python tuples and Numpy ndarrays translate to XLA
// tuples and XLA literals, respectively, and Numpy shape/dtype
// information is translated to XLA shape information.
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
#include <algorithm>
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace xla {
namespace swig {
namespace numpy {
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa.
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type);
PrimitiveType NumpyTypeToPrimitiveType(int np_type);
// Determines whether an integer-encoded Numpy dtype is valid,
// i.e. has a supported conversion to an XLA PrimitiveType.
bool NumpyTypeIsValid(int np_type);
// Converts XLA shape information into a Python pair of the form
// (numpy dtype, dimensions). If the XLA shape represents a tuple,
// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a
// Python tuple of shape-description pairs, created
// recursively. Otherwise, `dimensions` is a Python tuple-of-integers
// providing the array dimensions.
//
// The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
// Returns the outcome of a best-effort check that the Python object
// is a pair of the form (numpy dtype, dimensions), as produced by
// PyShapeInfoFromXlaShape.
bool CheckPyShapeInfo(PyObject* o);
// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
//
// The return value is a new reference.
Shape XlaShapeFromPyShapeInfo(PyObject* o);
// Converts an XLA literal to a Python object, either a Numpy ndarray
// or a nested Python tuple thereof.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
//
// The return value is a new reference.
PyObject* PyObjectFromXlaLiteral(const Literal& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal);
void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
PyArrayObject* py_array);
template <typename NativeT>
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
auto dest = literal->GetMutableArraySlice<NativeT>();
std::copy(source, source + PyArray_SIZE(py_array), dest.data());
}
template <typename NativeT>
void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
auto source = literal.GetArraySlice<NativeT>();
std::copy(source.begin(), source.end(), dest);
}
// Workarounds for Python 2 and 3 interop
PyObject* LongToPyIntOrPyLong(long x); // NOLINT
long PyIntOrPyLongToLong(PyObject* o); // NOLINT
bool CheckPyIntOrLong(PyObject* o);
PyObject* PyNumberToPyInt(PyObject* o);
} // namespace numpy
} // namespace swig
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
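
One practical consequence of the deep-copy contract documented above, sketched with the Python client from the end of this commit:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    builder = xla_client.ComputationBuilder('const_example')
    a = np.ones((2, 2), dtype=np.float32)
    builder.Constant(a)  # array data is deep-copied into the XLA literal
    a[:] = 0.0           # later mutation cannot affect the enqueued constant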

tensorflow/compiler/xla/python/xla.i

@@ -0,0 +1,18 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* XLA-wide SWIG wrapper */
%include "tensorflow/compiler/xla/python/local_computation_builder.i"

tensorflow/compiler/xla/python/xla_client.py

@@ -0,0 +1,605 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An in-process, local XLA client in Python, supporting AOT compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api
_UNARY_OPS = [
'Not',
'Abs',
'Exp',
'Floor',
'Ceil',
'Log',
'Sign',
'Cos',
'Sin',
'Tanh',
'SqrtF32',
'SquareF32',
'IsFinite',
'ReciprocalF32',
'Neg',
'Sort',
]
_BINARY_OPS = [
'Eq',
'Ne',
'Ge',
'Gt',
'Lt',
'Le',
'Add',
'Sub',
'Mul',
'Div',
'Rem',
'Max',
'Min',
'And',
'Or',
'Pow',
]
# Most functions are snake_case for consistency with other modules,
# whereas method names of ComputationBuilder and LocalComputation are
# CamelCase for consistency with XLA.
# pylint: disable=invalid-name
XLA_ELEMENT_TYPE_TO_DTYPE = {
xla_data_pb2.F32: np.dtype(np.float32),
xla_data_pb2.F64: np.dtype(np.float64),
xla_data_pb2.S32: np.dtype(np.int32),
xla_data_pb2.S64: np.dtype(np.int64),
xla_data_pb2.PRED: np.dtype(np.bool),
xla_data_pb2.TUPLE: np.dtype(np.object),
}
DTYPE_TO_XLA_ELEMENT_TYPE = {
str(v): k
for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}
class Shape(object):
"""XLA shape.
Represents an XLA shape by a corresponding Python/Numpy type and a
list of dimensions, which are themselves Shapes in case this one
represents an XLA tuple.
"""
def __init__(self, np_dtype, dimensions):
self.np_dtype = np_dtype
self._dimensions = dimensions
def element_type(self):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
def is_tuple(self):
return self.element_type() == xla_data_pb2.TUPLE
def dimensions(self):
if self.is_tuple():
raise ValueError('Tuple shape has no dimensions')
return self._dimensions
def tuple_shapes(self):
if not self.is_tuple():
raise ValueError('Shape is not a tuple shape')
return self._dimensions
@staticmethod
def from_numpy(npval):
def convert(npval):
if isinstance(npval, tuple):
return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval))
else:
return Shape(npval.dtype, np.shape(npval))
return convert(require_numpy_array_layout(npval))
def _wrap_shape(shape_info):
dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
if element_type == xla_data_pb2.TUPLE:
dims = [_wrap_shape(subshape_info) for subshape_info in dims]
return Shape(dtype, dims)
def _unwrap_shape(shape):
if shape.is_tuple():
components = tuple(
_unwrap_shape(subshape) for subshape in shape.tuple_shapes())
else:
components = shape.dimensions()
return (shape.np_dtype, components)
def _unwrap_shapes(shapes):
return [_unwrap_shape(shape) for shape in shapes]
def _wrap_data_handle(handle):
cdh = xla_data_pb2.ComputationDataHandle()
cdh.handle = handle
return cdh
def _unwrap_data_handle(handle_proto):
return handle_proto.handle
def _unwrap_data_handles(handle_protos):
return [_unwrap_data_handle(cdh) for cdh in handle_protos]
def require_numpy_array_layout(value):
if isinstance(value, tuple):
return tuple(require_numpy_array_layout(x) for x in value)
else:
return np.require(value, requirements=['C', 'A'])
class LocalComputation(object):
"""Python wrapper for a local XLA Computation.
A LocalComputation can be executed if it is compiled. Otherwise, it
can still be used as a Computation where required by the
ComputationBuilder methods.
"""
def __init__(self, c_local_computation, is_compiled):
self.c_local_computation = c_local_computation
self.is_compiled = is_compiled
# Ensure a reference to C-based destructor for use in __del__.
if is_compiled:
self._delete = c_api.DeleteCompiledLocalComputation
else:
self._delete = c_api.DeleteLocalComputation
def Compile(self, argument_shapes=()):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
return LocalComputation(
self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
is_compiled=True)
def CompileWithExampleArguments(self, arguments=()):
return self.Compile(
argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
def Execute(self, arguments=()):
if not self.is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
arguments = tuple(map(require_numpy_array_layout, arguments))
return self.c_local_computation.Execute(arguments)
def __del__(self):
self._delete(self.c_local_computation)
class ComputationBuilder(object):
"""XLA computation builder.
Enqueues XLA ops in sequence in order to build a LocalComputation,
which can in turn be compiled into a CompiledLocalComputation and
then executed locally.
"""
# The methods of this class map 1-to-1 onto the XLA C++
# computation builder API. Therefore, there's no need to laboriously list
# arguments and return values for every method, especially where it's obvious.
#
# pylint: disable=g-doc-return-or-yield
# pylint: disable=g-doc-args
def __init__(self, name):
self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
self._parameter_numbering = itertools.count()
def Build(self):
return LocalComputation(self._client.Build(), is_compiled=False)
def Constant(self, value):
"""Enqueues a constant op onto the computation.
Args:
value: value for the constant, as a np.array with an explicit dtype set
to one of the supported types.
Returns:
A ComputationDataHandle message.
"""
value = require_numpy_array_layout(value)
return _wrap_data_handle(self._client.ConstantLiteral(value))
def ConstantF32Scalar(self, value):
"""Convenience method to enqueue a scalar F32 constant op.
Args:
value: a floating-point number.
Returns:
A ComputationDataHandle message.
"""
return self.Constant(np.array(value, dtype=np.float32))
def ConstantF64Scalar(self, value):
"""Convenience method to enqueue a scalar F32 constant op.
Args:
value: a floating-point number.
Returns:
A ComputationDataHandle message.
"""
return self.Constant(np.array(value, dtype=np.float64))
def ConstantS32Scalar(self, value):
"""Convenience method to enqueue a scalar S32 constant op.
Args:
value: an integer.
Returns:
A ComputationDataHandle message.
"""
return self.Constant(np.array(value, dtype=np.int32))
def ConstantS64Scalar(self, value):
"""Convenience method to enqueue a scalar S64 constant op.
Args:
value: an integer.
Returns:
A ComputationDataHandle message.
"""
return self.Constant(np.array(value, dtype=np.int64))
def ConstantPredScalar(self, value):
"""Convenience method to enqueue a scalar PRED constant op.
Args:
value: a boolean value.
Returns:
A ComputationDataHandle message.
"""
return self.Constant(np.array(value, dtype=np.bool))
def ParameterWithShape(self, shape, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation, given a shape.
Args:
shape: the parameter's shape as a Shape object.
name: optional string name for the parameter.
parameter_num: parameter number in the computation function. If None,
the next sequential parameter number is used (i.e. parameters are
auto-numbered). If you use auto-numbering for some parameters, use it
for *all* parameters to avoid clashes.
Returns:
A ComputationDataHandle message.
"""
if name is None:
name = ''
if parameter_num is None:
parameter_num = next(self._parameter_numbering)
return _wrap_data_handle(
self._client.Parameter(
parameter_num, _unwrap_shape(shape), name.encode('utf8')))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
Args:
value: a Numpy array, or a nested tuple thereof, from which the
shape is inferred.
name: as in ParameterWithShape.
parameter_num: as in ParameterWithShape.
Returns:
A ComputationDataHandle message.
"""
return self.ParameterWithShape(
Shape.from_numpy(value), name=name, parameter_num=parameter_num)
def Broadcast(self, operand, sizes):
"""Enqueues a broadcast operation onto the computation.
Args:
operand: the operand ComputationDataHandle to broadcast.
sizes: an iterable of broadcast sizes.
Returns:
A ComputationDataHandle representing the added broadcast op.
"""
return _wrap_data_handle(
self._client.Broadcast(_unwrap_data_handle(operand), sizes))
def Concatenate(self, operands, dimension):
"""Enqueues a concatenate operation onto the computation.
Args:
operands: the operands to concatenate.
dimension: the dimension in which to perform the concatenation.
Returns:
A ComputationDataHandle representing the added concatenate op.
"""
return _wrap_data_handle(
self._client.ConcatInDim(_unwrap_data_handles(operands), dimension))
def ConvertElementType(self, operand, new_element_type):
"""Enqueues an element type conversion operation onto the computation.
Args:
operand: the operand to convert.
new_element_type: the target primitive type.
Returns:
A ComputationDataHandle representing the added conversion op.
"""
return _wrap_data_handle(
self._client.ConvertElementType(
_unwrap_data_handle(operand), new_element_type))
def GetShape(self, operand):
return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
def GetComputationStats(self):
raise NotImplementedError()
def Reshape(self, operand, dimensions, new_sizes):
"""Reshape op."""
return _wrap_data_handle(
self._client.Reshape(
_unwrap_data_handle(operand), dimensions, new_sizes))
def Trans(self, operand):
"""Specialized matrix transpose op."""
return _wrap_data_handle(
self._client.Transpose(_unwrap_data_handle(operand), [1, 0]))
def Transpose(self, operand, permutation):
"""Transpose op."""
return _wrap_data_handle(
self._client.Transpose(_unwrap_data_handle(operand), permutation))
def Select(self, pred, on_true, on_false):
"""Element-wise selection op.
Constructs an output array from elements of two input arrays, based on the
values of a predicate array.
"""
return _wrap_data_handle(
self._client.Select(
_unwrap_data_handle(pred),
_unwrap_data_handle(on_true),
_unwrap_data_handle(on_false)))
def Slice(self, operand, start_indices, limit_indices, strides=None):
"""Enqueues a slice operation onto the computation.
Args:
operand: ComputationDataHandle for the N dimensional array to be sliced.
start_indices: iterable of N integers containing the starting indices of
the slice for each dimension.
limit_indices: iterable of N integers containing the ending indices
(exclusive) of the slice for each dimension.
strides: optional iterable of N integers containing the stride sizes for
each dimension.
Returns:
A ComputationDataHandle representing the added Slice op.
"""
if strides is None:
start_indices = list(start_indices)
strides = [1] * len(start_indices)
return _wrap_data_handle(
self._client.Slice(
_unwrap_data_handle(operand),
start_indices,
limit_indices,
strides))
def DynamicSlice(self, operand, start_indices, slice_sizes):
"""Enqueues a slice op with dynamic start indices onto the computation.
Args:
operand: ComputationDataHandle for the N dimensional array to be sliced.
start_indices: ComputationDataHandle for the 1D array of N integers
containing the starting indices of the slice.
slice_sizes: iterable of N integers containing the slice sizes in each
dimension.
Returns:
A ComputationDataHandle representing the added DynamicSlice op.
"""
return _wrap_data_handle(
self._client.DynamicSlice(
_unwrap_data_handle(operand),
_unwrap_data_handle(start_indices),
slice_sizes))
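# Usage sketch, mirroring testDynamicSlice below: the start indices are a
# computation value rather than Python constants (`a` is a 3x3 constant as
# in the Slice sketch above).
#
#   starts = c.Constant(np.array([1, 0], dtype=np.int32))
#   c.DynamicSlice(a, starts, [2, 2])   # 2x2 window anchored at (1, 0)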
def DynamicUpdateSlice(self, operand, update, start_indices):
"""Enqueues a dynamic update slice operation onto the computation.
Args:
operand: ComputationDataHandle for the N dimensional array to be updated.
update: ComputationDataHandle for the N dimensional array comprising the
slice update.
start_indices: ComputationDataHandle for the rank-1 array of N integers
comprising the starting indices of the slice along each dimension.
Returns:
A ComputationDataHandle representing the added DynamicUpdateSlice op.
"""
return _wrap_data_handle(
self._client.DynamicUpdateSlice(
_unwrap_data_handle(operand),
_unwrap_data_handle(update),
_unwrap_data_handle(start_indices)))
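# Usage sketch, mirroring testDynamicUpdateSlice below: write a 2x2 update
# into a 3x3 operand `a` at dynamic offset (1, 1).
#
#   upd = c.Constant(np.array([[1, 2], [3, 4]], dtype=np.int32))
#   starts = c.Constant(np.array([1, 1], dtype=np.int32))
#   c.DynamicUpdateSlice(a, upd, starts)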
def Tuple(self, *ops):
"""Enqueues a tuple operation onto the computation.
Args:
ops: a sequence of tuple operands (each a ComputationDataHandle).
Returns:
A ComputationDataHandle representing the added Tuple op.
"""
return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops)))
def GetTupleElement(self, tup, index):
"""Enqueues a 'get tuple element' operation onto the computation.
Args:
tup: the tuple operand (a ComputationDataHandle).
index: numeric index to select from the tuple.
Returns:
A ComputationDataHandle representing the added GetTupleElement op.
"""
return _wrap_data_handle(
self._client.GetTupleElement(_unwrap_data_handle(tup), index))
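# Usage sketch, mirroring testGetTupleElement below:
#
#   t = c.Tuple(c.ConstantS32Scalar(42),
#               c.Constant(np.array([1.0, 2.0], dtype=np.float32)))
#   c.GetTupleElement(t, 1)   # selects the f32[2] element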
def Call(self, computation_to_apply, operands):
"""Enqueues a call operation onto the computation.
Args:
computation_to_apply: a Computation object.
operands: an iterable of ComputationDataHandle. The number and types of
operands must match the arity of computation_to_apply.
Returns:
A ComputationDataHandle representing the added call op.
"""
return _wrap_data_handle(
self._client.Call(computation_to_apply.c_local_computation,
_unwrap_data_handles(operands)))
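# Usage sketch, mirroring testCallF32 below; `mul_by_two` stands for a
# Computation returned by another builder's Build():
#
#   c.Call(mul_by_two, operands=(c.ConstantF32Scalar(5.0),))   # yields 10.0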
def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
"""Enqueues a map operation onto the computation.
Args:
operands: an iterable of ComputationDataHandle.
computation_to_apply: a Computation object.
dimensions: dimensions over which to map the function.
static_operands: auxiliary arguments passed to the applied computation.
Returns:
A ComputationDataHandle representing the added Map op.
"""
return _wrap_data_handle(
self._client.Map(
_unwrap_data_handles(operands),
computation_to_apply.c_local_computation,
dimensions,
_unwrap_data_handles(static_operands)))
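# Usage sketch, mirroring testMapMulBy2F32 below; `doubler` stands for an
# f32 -> f32 Computation built separately:
#
#   c.Map([c.Constant(np.array([1., 2., 3., 4.], dtype=np.float32))],
#         doubler, [0])   # yields [2., 4., 6., 8.]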
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
"""Enqueues a reduction operation onto the computation.
Args:
operand: reduction operand (ComputationDataHandle).
init_value: reduction initial value (ComputationDataHandle).
computation_to_apply: a Computation object - binary reduction function.
dimensions: sequence of dimensions (integers) to reduce on.
Returns:
A ComputationDataHandle representing the added Reduce op.
"""
return _wrap_data_handle(
self._client.Reduce(
_unwrap_data_handle(operand),
_unwrap_data_handle(init_value),
computation_to_apply.c_local_computation,
dimensions))
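# Usage sketch, mirroring testReduce2DTo1DDim0F32 below; `add_f32` stands
# for a binary (f32, f32) -> f32 Computation:
#
#   c.Reduce(operand=c.Constant(np.array([[1., 2., 3.], [4., 5., 6.]],
#                                        dtype=np.float32)),
#            init_value=c.ConstantF32Scalar(0),
#            computation_to_apply=add_f32,
#            dimensions=[0])   # column sums: [5., 7., 9.]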
def While(self, cond, body, init):
"""Enqueues a While operation onto the computation.
Args:
cond: a Computation for the loop condition, which has type T -> PRED.
body: a Computation for the loop body, which has type T -> T.
init: a ComputationDataHandle for the initial parameter, which has type T.
Returns:
A ComputationDataHandle representing the added While op.
"""
return _wrap_data_handle(
self._client.While(cond.c_local_computation,
body.c_local_computation,
_unwrap_data_handle(init)))
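# Usage sketch, mirroring testWhileF32 below: starting from 1.0, keep
# doubling while the value is below 10 (final result 16.0).
#
#   cond = ...  # Computation of type f32 -> PRED, e.g. x < 10
#   body = ...  # Computation of type f32 -> f32, e.g. x * 2
#   c.While(cond, body, c.ConstantF32Scalar(1.))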
def Dot(self, lhs, rhs):
"""Matrix multiplication between lhs and rhs."""
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.
Set up methods corresponding to unary and binary XLA operations, whose
calls are forwarded in a boilerplate manner to the underlying
LocalComputationBuilder C-extension API.
"""
def forward_to_local_builder_with_handles(target_method, is_binop=False):
"""Generate a forwarding method that wraps/unwraps data handles."""
def forward(self, *args, **kwargs):
unwrapped_args = [_unwrap_data_handle(arg) for arg in args]
if is_binop and len(unwrapped_args) < 3:
unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))
return _wrap_data_handle(
target_method(
self._client, # pylint: disable=protected-access
*unwrapped_args))
return forward
for method_name in _UNARY_OPS:
forward = forward_to_local_builder_with_handles(
getattr(c_api.LocalComputationBuilder, method_name))
forward.__name__ = method_name
setattr(ComputationBuilder, method_name, forward)
for method_name in _BINARY_OPS:
forward = forward_to_local_builder_with_handles(
getattr(c_api.LocalComputationBuilder, method_name), is_binop=True)
forward.__name__ = method_name
setattr(ComputationBuilder, method_name, forward)
_forward_methods_to_local_builder()
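# Note (illustrative sketch): after the call above, every name listed in
# _UNARY_OPS and _BINARY_OPS is available directly on ComputationBuilder:
#
#   c.Exp(x)                                # unary: one handle in, one out
#   c.Add(m, v, broadcast_dimensions=(0,))  # binary, optional broadcast dims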


@ -0,0 +1,898 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Python extension-based XLA client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import unittest
import numpy as np
from tensorflow.compiler.xla.python import xla_client
class LocalComputationTest(unittest.TestCase):
"""Base class for running an XLA Computation through the local client."""
def _NewComputation(self, name=None):
if name is None:
name = self.id()
return xla_client.ComputationBuilder(name)
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
assert expected is not None
compiled_c = c.Build().CompileWithExampleArguments(arguments)
result = compiled_c.Execute(arguments)
# Numpy's comparison methods are a bit too lenient in treating inputs as
# "array-like", meaning that scalar 4 will happily compare equal to [[4]].
# We'd like to be stricter, so we assert shapes as well.
self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
assert_func(result, expected)
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
expected)
def NumpyArrayF32(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
return np.array(*args, dtype=np.float32, **kwargs)
def NumpyArrayF64(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
return np.array(*args, dtype=np.float64, **kwargs)
def NumpyArrayS32(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
return np.array(*args, dtype=np.int32, **kwargs)
def NumpyArrayS64(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.int64 dtype."""
return np.array(*args, dtype=np.int64, **kwargs)
def NumpyArrayBool(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.bool dtype."""
return np.array(*args, dtype=np.bool, **kwargs)
class ComputationsWithConstantsTest(LocalComputationTest):
"""Tests focusing on Constant ops."""
def testConstantScalarSumF32(self):
c = self._NewComputation()
c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
self._ExecuteAndCompareClose(c, expected=4.25)
def testConstantScalarSumF64(self):
c = self._NewComputation()
c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
self._ExecuteAndCompareClose(c, expected=4.25)
def testConstantScalarSumS32(self):
c = self._NewComputation()
c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
self._ExecuteAndCompareClose(c, expected=3)
def testConstantScalarSumS64(self):
c = self._NewComputation()
c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
self._ExecuteAndCompareClose(c, expected=3)
def testConstantVectorMulF32(self):
c = self._NewComputation()
c.Mul(
c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
def testConstantVectorMulF64(self):
c = self._NewComputation()
c.Mul(
c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
def testConstantVectorScalarDivF32(self):
c = self._NewComputation()
c.Div(
c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
c.ConstantF32Scalar(2.0))
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
def testConstantVectorScalarDivF64(self):
c = self._NewComputation()
c.Div(
c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
c.ConstantF64Scalar(2.0))
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
def testConstantVectorScalarPowF32(self):
c = self._NewComputation()
c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
def testConstantVectorScalarPowF64(self):
c = self._NewComputation()
c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
def testBooleanAnd(self):
c = self._NewComputation()
c.And(
c.Constant(NumpyArrayBool([True, False, True, False])),
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, False])
def testBooleanOr(self):
c = self._NewComputation()
c.Or(
c.Constant(NumpyArrayBool([True, False, True, False])),
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
def testSum2DF32(self):
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
def testSum2DF64(self):
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
def testSum2DWith1DBroadcastDim0F32(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 0 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([10, 20, 30])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareClose(
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
def testSum2DWith1DBroadcastDim0F64(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 0 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF64([10, 20, 30])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareClose(
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
def testSum2DWith1DBroadcastDim1F32(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 1 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([10, 20, 30])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareClose(
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
def testSum2DWith1DBroadcastDim1F64(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 1 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF64([10, 20, 30])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareClose(
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
def testConstantAxpyF32(self):
c = self._NewComputation()
c.Add(
c.Mul(
c.ConstantF32Scalar(2),
c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
c.Constant(NumpyArrayF32([100, -100, 200, -200])))
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
def testConstantAxpyF64(self):
c = self._NewComputation()
c.Add(
c.Mul(
c.ConstantF64Scalar(2),
c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
c.Constant(NumpyArrayF64([100, -100, 200, -200])))
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
class ParametersTest(LocalComputationTest):
"""Tests focusing on Parameter ops and argument-passing."""
def setUp(self):
self.f32_scalar_2 = NumpyArrayF32(2.0)
self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
self.f64_scalar_2 = NumpyArrayF64(2.0)
self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
self.s32_scalar_3 = NumpyArrayS32(3)
self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
self.s64_scalar_3 = NumpyArrayS64(3)
self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])
def testScalarTimesVectorAutonumberF32(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.f32_scalar_2)
p1 = c.ParameterFromNumpy(self.f32_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareClose(
c,
arguments=[self.f32_scalar_2, self.f32_4vector],
expected=[-4.6, 6.6, -8.6, 10.6])
def testScalarTimesVectorAutonumberF64(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.f64_scalar_2)
p1 = c.ParameterFromNumpy(self.f64_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareClose(
c,
arguments=[self.f64_scalar_2, self.f64_4vector],
expected=[-4.6, 6.6, -8.6, 10.6])
def testScalarTimesVectorS32(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.s32_scalar_3)
p1 = c.ParameterFromNumpy(self.s32_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareExact(
c,
arguments=[self.s32_scalar_3, self.s32_4vector],
expected=[30, 45, -6, 21])
def testScalarTimesVectorS64(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.s64_scalar_3)
p1 = c.ParameterFromNumpy(self.s64_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareExact(
c,
arguments=[self.s64_scalar_3, self.s64_4vector],
expected=[30, 45, -6, 21])
def testScalarMinusVectorExplicitNumberingF32(self):
# Use explicit numbering and pass parameter_num first. Sub is used since
# it's not commutative and can help catch parameter reversal within the
# computation.
c = self._NewComputation()
p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
c.Sub(p1, p0)
self._ExecuteAndCompareClose(
c,
arguments=[self.f32_scalar_2, self.f32_4vector],
expected=[-4.3, 1.3, -6.3, 3.3])
def testScalarMinusVectorExplicitNumberingF64(self):
# Use explicit numbering and pass parameter_num first. Sub is used since
# it's not commutative and can help catch parameter reversal within the
# computation.
c = self._NewComputation()
p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
c.Sub(p1, p0)
self._ExecuteAndCompareClose(
c,
arguments=[self.f64_scalar_2, self.f64_4vector],
expected=[-4.3, 1.3, -6.3, 3.3])
class SingleOpTest(LocalComputationTest):
"""Tests for single ops.
The goal here is smoke testing: exercising the most basic functionality of
single XLA ops. As few additional ops as possible are added around the op
under test.
"""
def testConcatenateF32(self):
c = self._NewComputation()
c.Concatenate(
(c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConcatenateF64(self):
c = self._NewComputation()
c.Concatenate(
(c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConvertElementType(self):
xla_types = {
np.bool: xla_client.xla_data_pb2.PRED,
np.int32: xla_client.xla_data_pb2.S32,
np.int64: xla_client.xla_data_pb2.S64,
np.float32: xla_client.xla_data_pb2.F32,
np.float64: xla_client.xla_data_pb2.F64,
}
def _ConvertAndTest(template, src_dtype, dst_dtype):
c = self._NewComputation()
x = c.Constant(np.array(template, dtype=src_dtype))
c.ConvertElementType(x, xla_types[dst_dtype])
result = c.Build().Compile().Execute()
expected = np.array(template, dtype=dst_dtype)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(result.dtype, expected.dtype)
np.testing.assert_equal(result, expected)
x = [0, 1, 0, 0, 1]
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
_ConvertAndTest(x, src_dtype, dst_dtype)
def testDotMatrixVectorF32(self):
c = self._NewComputation()
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF32([[10.0], [20.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixVectorF64(self):
c = self._NewComputation()
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF64([[10.0], [20.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixMatrixF32(self):
c = self._NewComputation()
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixMatrixF64(self):
c = self._NewComputation()
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
c.Not(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=~arr)
def testExp(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
def testLog(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Log(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log(arr))
def testNeg(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Neg(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=-arr)
def testFloor(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Floor(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.floor(arr))
def testCeil(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Ceil(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
def testAbs(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
c.Abs(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.abs(arr))
def testTanh(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Tanh(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
def testTrans(self):
def _TransposeAndTest(array):
c = self._NewComputation()
c.Trans(c.Constant(array))
self._ExecuteAndCompareClose(c, expected=array.T)
# Test square and non-square matrices in both default (C) and F orders.
for array_fun in [NumpyArrayF32, NumpyArrayF64]:
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
_TransposeAndTest(array_fun([[1, 2], [4, 5]]))
_TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
def testTranspose(self):
def _TransposeAndTest(array, permutation):
c = self._NewComputation()
c.Transpose(c.Constant(array), permutation)
expected = np.transpose(array, permutation)
self._ExecuteAndCompareClose(c, expected=expected)
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
for permutation in itertools.permutations(range(arr.ndim)):
_TransposeAndTest(arr, permutation)
_TransposeAndTest(np.asfortranarray(arr), permutation)
def testEq(self):
c = self._NewComputation()
c.Eq(
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
def testNe(self):
c = self._NewComputation()
c.Ne(
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
c.Ne(
c.Constant(NumpyArrayF32([-2.0, 0.0,
float("nan"),
float("nan")])),
c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
self._ExecuteAndAssertWith(
np.testing.assert_allclose, c, (), expected=[True, False, True, True])
def testGt(self):
c = self._NewComputation()
c.Gt(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
def testGe(self):
c = self._NewComputation()
c.Ge(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
def testLt(self):
c = self._NewComputation()
c.Lt(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
def testLe(self):
c = self._NewComputation()
c.Le(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
def testMax(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
def testMaxExplicitBroadcastDim0(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([3, 4, 5])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
def testMaxExplicitBroadcastDim1(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([3, 4, 5])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
def testMin(self):
c = self._NewComputation()
c.Min(
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
def testReshape(self):
c = self._NewComputation()
c.Reshape(
c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
dimensions=[0, 1],
new_sizes=[2, 3])
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
def testSelect(self):
c = self._NewComputation()
c.Select(
c.Constant(NumpyArrayBool([True, False, False, True, False])),
c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
def testSlice(self):
c = self._NewComputation()
c.Slice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
[3, 2])
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
def testDynamicSlice(self):
c = self._NewComputation()
c.DynamicSlice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayS32([1, 0])), [2, 2])
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
def testDynamicUpdateSlice(self):
c = self._NewComputation()
c.DynamicUpdateSlice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
c.Constant(NumpyArrayS32([1, 1])))
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
def testTuple(self):
c = self._NewComputation()
c.Tuple(
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
c.Constant(NumpyArrayBool([True, False, False, True])))
result = c.Build().Compile().Execute()
self.assertIsInstance(result, tuple)
np.testing.assert_equal(result[0], 42)
np.testing.assert_allclose(result[1], [1.0, 2.0])
np.testing.assert_equal(result[2], [True, False, False, True])
def testGetTupleElement(self):
c = self._NewComputation()
c.GetTupleElement(
c.Tuple(
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
def testBroadcast(self):
c = self._NewComputation()
c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
self._ExecuteAndCompareExact(
c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
class EmbeddedComputationsTest(LocalComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps)."""
def _CreateConstantS32Computation(self):
"""Computation (f32) -> s32 that returns a constant 1 for any input."""
c = self._NewComputation("constant_s32_one")
# TODO(eliben): consider adding a nicer way to create new parameters without
# having to create dummy Numpy arrays or populating Shape messages. Perhaps
# we need our own (Python-client-own) way to represent Shapes conveniently.
c.ParameterFromNumpy(NumpyArrayF32(0))
c.ConstantS32Scalar(1)
return c.Build()
def _CreateConstantS64Computation(self):
"""Computation (f64) -> s64 that returns a constant 1 for any input."""
c = self._NewComputation("constant_s64_one")
# TODO(eliben): consider adding a nicer way to create new parameters without
# having to create dummy Numpy arrays or populating Shape messages. Perhaps
# we need our own (Python-client-own) way to represent Shapes conveniently.
c.ParameterFromNumpy(NumpyArrayF64(0))
c.ConstantS64Scalar(1)
return c.Build()
def _CreateConstantF32Computation(self):
"""Computation (f32) -> f32 that returns a constant 1.0 for any input."""
c = self._NewComputation("constant_f32_one")
c.ParameterFromNumpy(NumpyArrayF32(0))
c.ConstantF32Scalar(1.0)
return c.Build()
def _CreateConstantF64Computation(self):
"""Computation (f64) -> f64 that returns a constant 1.0 for any input."""
c = self._NewComputation("constant_f64_one")
c.ParameterFromNumpy(NumpyArrayF64(0))
c.ConstantF64Scalar(1.0)
return c.Build()
def _CreateMulF32By2Computation(self):
"""Computation (f32) -> f32 that multiplies its parameter by 2."""
c = self._NewComputation("mul_f32_by2")
c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
return c.Build()
def _CreateMulF64By2Computation(self):
"""Computation (f64) -> f64 that multiplies its parameter by 2."""
c = self._NewComputation("mul_f64_by2")
c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
return c.Build()
def _CreateBinaryAddF32Computation(self):
"""Computation (f32, f32) -> f32 that adds its two parameters."""
c = self._NewComputation("add_param0_by_param1")
c.Add(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryAddF64Computation(self):
"""Computation (f64, f64) -> f64 that adds its two parameters."""
c = self._NewComputation("add_param0_by_param1")
c.Add(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _CreateBinaryDivF32Computation(self):
"""Computation (f32, f32) -> f32 that divides its two parameters."""
c = self._NewComputation("div_param0_by_param1")
c.Div(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryDivF64Computation(self):
"""Computation (f64, f64) -> f64 that divides its two parameters."""
c = self._NewComputation("div_param0_by_param1")
c.Div(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _CreateTestF32Lt10Computation(self):
"""Computation (f32) -> bool that tests if its parameter is less than 10."""
c = self._NewComputation("test_f32_lt_10")
c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
return c.Build()
def _CreateTestF64Lt10Computation(self):
"""Computation (f64) -> bool that tests if its parameter is less than 10."""
c = self._NewComputation("test_f64_lt_10")
c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
return c.Build()
def _MakeSample3DArrayF32(self):
return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
def _MakeSample3DArrayF64(self):
return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
def testCallF32(self):
c = self._NewComputation()
c.Call(
self._CreateMulF32By2Computation(),
operands=(c.ConstantF32Scalar(5.0),))
self._ExecuteAndCompareClose(c, expected=10.0)
def testCallF64(self):
c = self._NewComputation()
c.Call(
self._CreateMulF64By2Computation(),
operands=(c.ConstantF64Scalar(5.0),))
self._ExecuteAndCompareClose(c, expected=10.0)
def testMapEachElementToS32Constant(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantS32Computation(), [0])
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
def testMapEachElementToS64Constant(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantS64Computation(), [0])
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
def testMapMulBy2F32(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateMulF32By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
def testMapMulBy2F64(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateMulF64By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
def testSimpleMapChainF32(self):
# Chains a map of constant-f32 with a map of mul-by-2
c = self._NewComputation()
const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantF32Computation(), [0])
c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
def testSimpleMapChainF64(self):
# Chains a map of constant-f64 with a map of mul-by-2
c = self._NewComputation()
const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantF64Computation(), [0])
c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
def testDivVectorsWithMapF32(self):
c = self._NewComputation()
c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
self._CreateBinaryDivF32Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
def testDivVectorsWithMapF64(self):
c = self._NewComputation()
c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
self._CreateBinaryDivF64Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
def testReduce1DtoScalarF32(self):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=10)
def testReduce1DtoScalarF64(self):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=10)
def testReduce2DTo1DDim0F32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
def testReduce2DTo1DDim0F64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
def testReduce2DTo1DDim1F32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[1])
self._ExecuteAndCompareClose(c, expected=[6, 15])
def testReduce2DTo1DDim1F64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[1])
self._ExecuteAndCompareClose(c, expected=[6, 15])
def testReduce3DAllPossibleWaysF32(self):
input_array = self._MakeSample3DArrayF32()
def _ReduceAndTest(*dims):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=dims)
self._ExecuteAndCompareClose(
c, expected=np.sum(input_array, axis=tuple(dims)))
_ReduceAndTest(0)
_ReduceAndTest(1)
_ReduceAndTest(2)
_ReduceAndTest(0, 1)
_ReduceAndTest(0, 2)
_ReduceAndTest(1, 2)
_ReduceAndTest(0, 1, 2)
def testReduce3DAllPossibleWaysF64(self):
input_array = self._MakeSample3DArrayF64()
def _ReduceAndTest(*dims):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=dims)
self._ExecuteAndCompareClose(
c, expected=np.sum(input_array, axis=tuple(dims)))
_ReduceAndTest(0)
_ReduceAndTest(1)
_ReduceAndTest(2)
_ReduceAndTest(0, 1)
_ReduceAndTest(0, 2)
_ReduceAndTest(1, 2)
_ReduceAndTest(0, 1, 2)
def testWhileF32(self):
cond = self._CreateTestF32Lt10Computation()
body = self._CreateMulF32By2Computation()
c = self._NewComputation()
init = c.ConstantF32Scalar(1.)
c.While(cond, body, init)
self._ExecuteAndCompareClose(c, expected=16.)
def testWhileF64(self):
cond = self._CreateTestF64Lt10Computation()
body = self._CreateMulF64By2Computation()
c = self._NewComputation()
init = c.ConstantF64Scalar(1.)
c.While(cond, body, init)
self._ExecuteAndCompareClose(c, expected=16.)
if __name__ == "__main__":
unittest.main()


@ -4,3 +4,4 @@
*TF_*
*TFE_*
*nsync_*
*pywrap_xla*


@ -5,6 +5,7 @@ tensorflow {
*TF_*;
*TFE_*;
*nsync_*;
*pywrap_xla*;
local:
*;
};