Python library and C++ bindings for creating and compiling local XLA computations.
PiperOrigin-RevId: 179211353
commit 75a91cf3be
parent 22fe6558a9
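For orientation, the pieces added below compose roughly as follows. This sketch is not part of the commit: the computation name and example values are hypothetical, and the intended entry point for end users is the higher-level wrapper in xla_client.py rather than the raw SWIG module used here.

    import numpy as np
    from tensorflow.compiler.xla.python import pywrap_xla as c_api

    # Build a computation that adds one to a scalar f32 parameter. Handles are
    # plain Python ints/longs, shapes are (dtype, dimensions) pairs, and
    # literals are numpy ndarrays, per the SWIG typemaps added below.
    builder = c_api.LocalComputationBuilder('add_one')  # hypothetical name
    x = builder.Parameter(0, (np.dtype('float32'), ()), 'x')
    one = builder.ConstantLiteral(np.array(1.0, dtype=np.float32))
    builder.Add(x, one, ())  # empty broadcast_dimensions
    computation = builder.Build()

    # Compile for the local (CPU) client, then execute against a concrete
    # argument; the result comes back as a numpy ndarray.
    compiled = computation.Compile([(np.dtype('float32'), ())])
    result = compiled.Execute([np.array(41.0, dtype=np.float32)])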
@@ -411,6 +411,7 @@ filegroup(
        "//tensorflow/compiler/xla/client:all_files",
        "//tensorflow/compiler/xla/client/lib:all_files",
        "//tensorflow/compiler/xla/legacy_flags:all_files",
        "//tensorflow/compiler/xla/python:all_files",
        "//tensorflow/compiler/xla/service:all_files",
        "//tensorflow/compiler/xla/service/cpu:all_files",
        "//tensorflow/compiler/xla/service/gpu:all_files",
@@ -20,6 +20,10 @@ package_group(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_proto_library_py",
)

# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -36,6 +40,12 @@ xla_proto_library(
    visibility = ["//visibility:public"],
)

tf_proto_library_py(
    name = "xla_data_proto",  # bzl adds a _py suffix
    srcs = ["xla_data.proto"],
    visibility = ["//visibility:public"],
)

xla_proto_library(
    name = "xla_proto",
    srcs = ["xla.proto"],
tensorflow/compiler/xla/python/BUILD (new file, 82 lines)
@@ -0,0 +1,82 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow:internal"])

load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")

py_library(
    name = "xla_client",
    srcs = ["xla_client.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":pywrap_xla",
        "//tensorflow/compiler/xla:xla_data_proto_py",
    ],
)

py_test(
    name = "xla_client_test",
    srcs = ["xla_client_test.py"],
    main = "xla_client_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":xla_client",
        "//tensorflow/python:platform_test",
    ],
)

cc_library(
    name = "numpy_bridge",
    srcs = ["numpy_bridge.cc"],
    hdrs = ["numpy_bridge.h"],
    deps = [
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "//tensorflow/python:numpy_lib",
    ],
)

cc_library(
    name = "local_computation_builder",
    srcs = ["local_computation_builder.cc"],
    hdrs = ["local_computation_builder.h"],
    deps = [
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:cpu_plugin",
        "//tensorflow/core:lib",
    ],
)

tf_py_wrap_cc(
    name = "pywrap_xla",
    srcs = ["xla.i"],
    swig_includes = [
        "local_computation_builder.i",
    ],
    deps = [
        ":local_computation_builder",
        ":numpy_bridge",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
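For orientation (not part of this commit), the py_library and tf_py_wrap_cc targets above are consumed from Python with imports along these lines; xla_client itself builds on the SWIG-generated pywrap_xla extension:

    # Hypothetical downstream usage of the targets declared above.
    from tensorflow.compiler.xla.python import xla_client            # :xla_client
    from tensorflow.compiler.xla.python import pywrap_xla as c_api   # :pywrap_xla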
tensorflow/compiler/xla/python/__init__.py (new empty file)
tensorflow/compiler/xla/python/local_computation_builder.cc (new file, 265 lines)
@@ -0,0 +1,265 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace swig {

CompiledLocalComputation::CompiledLocalComputation(
    std::unique_ptr<LocalExecutable> executable)
    : executable_(std::move(executable)) {}

std::unique_ptr<Literal> CompiledLocalComputation::Execute(
    const std::vector<Literal>& arguments) {
  LocalClient* client = ClientLibrary::LocalClientOrDie();

  // Transfer arguments in
  std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
  scoped_buffers.reserve(arguments.size());
  for (const Literal& argument : arguments) {
    scoped_buffers.push_back(
        client
            ->LiteralToShapedBuffer(argument,
                                    /*device_ordinal=*/0,
                                    client->backend().memory_allocator())
            .ConsumeValueOrDie());
  }

  // Execute
  std::vector<const ShapedBuffer*> argument_buffers;
  argument_buffers.reserve(scoped_buffers.size());
  for (auto& buffer : scoped_buffers) {
    argument_buffers.push_back(buffer.get());
  }
  ExecutableRunOptions options;
  options.set_allocator(client->backend().memory_allocator());
  options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
  options.set_intra_op_thread_pool(
      client->backend().eigen_intra_op_thread_pool_device());
  std::unique_ptr<ScopedShapedBuffer> result_buffer =
      executable_->Run(argument_buffers, options).ConsumeValueOrDie();

  // Transfer result out
  return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie();
}

LocalComputation::LocalComputation(std::unique_ptr<Computation> computation)
    : computation_(std::move(computation)) {}

CompiledLocalComputation* LocalComputation::Compile(
    const std::vector<Shape>& argument_shapes) {
  std::vector<const Shape*> argument_shape_pointers;
  argument_shape_pointers.reserve(argument_shapes.size());
  for (auto& argument_shape : argument_shapes) {
    argument_shape_pointers.push_back(&argument_shape);
  }

  LocalClient* client = ClientLibrary::LocalClientOrDie();
  ExecutableBuildOptions options;
  return new CompiledLocalComputation(
      client->Compile(*computation_, argument_shape_pointers, options)
          .ValueOrDie());
}

const Computation& LocalComputation::computation() const {
  return *computation_;
}

LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
    : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {}

LocalComputation* LocalComputationBuilder::Build() {
  return new LocalComputation(std::unique_ptr<Computation>(
      new Computation(builder_.Build().ConsumeValueOrDie())));
}

ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number,
                                                         const Shape& shape,
                                                         const string& name) {
  return builder_.Parameter(parameter_number, shape, name);
}

std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
    const ComputationDataHandle& operand) {
  return builder_.GetShape(operand).ConsumeValueOrDie();
}

ComputationDataHandle LocalComputationBuilder::ConstantLiteral(
    const Literal& literal) {
  return builder_.ConstantLiteral(literal);
}

ComputationDataHandle LocalComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  return builder_.Broadcast(operand, broadcast_sizes);
}

ComputationDataHandle LocalComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  return builder_.Reshape(operand, dimensions, new_sizes);
}

ComputationDataHandle LocalComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  return builder_.Slice(operand, start_indices, limit_indices, strides);
}

ComputationDataHandle LocalComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  return builder_.DynamicSlice(operand, start_indices, slice_sizes);
}

ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  return builder_.DynamicUpdateSlice(operand, update, start_indices);
}

ComputationDataHandle LocalComputationBuilder::ConcatInDim(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    int64 dimension) {
  return builder_.ConcatInDim(operands, dimension);
}

ComputationDataHandle LocalComputationBuilder::Select(
    const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
    const ComputationDataHandle& on_false) {
  return builder_.Select(pred, on_true, on_false);
}

ComputationDataHandle LocalComputationBuilder::Tuple(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
  return builder_.Tuple(elements);
}

ComputationDataHandle LocalComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  return builder_.GetTupleElement(tuple_data, index);
}

ComputationDataHandle LocalComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  return builder_.Dot(lhs, rhs);
}

ComputationDataHandle LocalComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  return builder_.ConvertElementType(operand, new_element_type);
}

ComputationDataHandle LocalComputationBuilder::Call(
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
  return builder_.Call(local_computation.computation(), operands);
}

ComputationDataHandle LocalComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  return builder_.Transpose(operand, permutation);
}

ComputationDataHandle LocalComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
  return builder_.Map(operands, local_computation.computation(), dimensions,
                      static_operands);
}

ComputationDataHandle LocalComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  return builder_.Reduce(operand, init_value, local_computation.computation(),
                         dimensions_to_reduce);
}

ComputationDataHandle LocalComputationBuilder::While(
    const LocalComputation& condition, const LocalComputation& body,
    const ComputationDataHandle& init) {
  return builder_.While(condition.computation(), body.computation(), init);
}

#define _FORWARD(method_name, return_sig, args_sig, args) \
  return_sig LocalComputationBuilder::method_name args_sig { \
    return builder_.method_name args; \
  }

#define _FORWARD_UNOP(method_name) \
  _FORWARD(method_name, ComputationDataHandle, \
           (const ComputationDataHandle& operand), (operand))

#define _FORWARD_BINOP(method_name) \
  _FORWARD( \
      method_name, ComputationDataHandle, \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
      (lhs, rhs, broadcast_dimensions))

_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(SqrtF32)
_FORWARD_UNOP(SquareF32)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(ReciprocalF32)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)

#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP

}  // namespace swig

}  // namespace xla
tensorflow/compiler/xla/python/local_computation_builder.h (new file, 210 lines)
@@ -0,0 +1,210 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace xla {

namespace swig {

// Wraps a LocalExecutable produced by compiling a
// LocalComputation. The Execute method forwards to that of the
// underlying LocalExecutable, and additionally handles transferring
// arguments and return values in and back out of the client library's
// local client. This class is intended to be made available to Python
// via SWIG.
class CompiledLocalComputation {
 public:
  CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
  std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments);

 private:
  std::unique_ptr<LocalExecutable> executable_;
};

// Wraps a Computation produced by a LocalComputationBuilder. The
// Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be
// made available to Python via SWIG.
class LocalComputation {
 public:
  LocalComputation(std::unique_ptr<Computation> computation);
  CompiledLocalComputation* Compile(const std::vector<Shape>& argument_shapes);
  const Computation& computation() const;

 private:
  std::unique_ptr<Computation> computation_;
};

// Wraps the ComputationBuilder API in order to:
// - Support consumption by SWIG in order to be made available to
//   Python.
// - Set up the underlying builder to use the client library's
//   LocalClient.
// - Wrap Computations in LocalComputations for Python access.
// - Correspondingly unwrap incoming LocalComputations.
class LocalComputationBuilder {
 public:
  LocalComputationBuilder(const string& computation_name);

  LocalComputation* Build();

  ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
                                  const string& name);

  std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);

  ComputationDataHandle ConstantLiteral(const Literal& literal);

  ComputationDataHandle Broadcast(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_sizes);

  ComputationDataHandle Reshape(const ComputationDataHandle& operand,
                                tensorflow::gtl::ArraySlice<int64> dimensions,
                                tensorflow::gtl::ArraySlice<int64> new_sizes);

  ComputationDataHandle Slice(const ComputationDataHandle& operand,
                              tensorflow::gtl::ArraySlice<int64> start_indices,
                              tensorflow::gtl::ArraySlice<int64> limit_indices,
                              tensorflow::gtl::ArraySlice<int64> strides);

  ComputationDataHandle DynamicSlice(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  ComputationDataHandle DynamicUpdateSlice(
      const ComputationDataHandle& operand, const ComputationDataHandle& update,
      const ComputationDataHandle& start_indices);

  ComputationDataHandle ConcatInDim(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      int64 dimension);

  ComputationDataHandle Select(const ComputationDataHandle& pred,
                               const ComputationDataHandle& on_true,
                               const ComputationDataHandle& on_false);

  ComputationDataHandle Tuple(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);

  ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
                                        int64 index);

  ComputationDataHandle Dot(const ComputationDataHandle& lhs,
                            const ComputationDataHandle& rhs);

  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);

  ComputationDataHandle Call(
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);

  ComputationDataHandle Transpose(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> permutation);

  ComputationDataHandle Map(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands);

  ComputationDataHandle Reduce(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value,
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);

  ComputationDataHandle While(const LocalComputation& condition,
                              const LocalComputation& body,
                              const ComputationDataHandle& init);

#define _FORWARD(method_name, return_sig, args_sig) \
  return_sig method_name args_sig;

#define _FORWARD_UNOP(method_name) \
  _FORWARD(method_name, ComputationDataHandle, \
           (const ComputationDataHandle& operand))

#define _FORWARD_BINOP(method_name) \
  _FORWARD( \
      method_name, ComputationDataHandle, \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))

  _FORWARD_BINOP(Eq)
  _FORWARD_BINOP(Ne)
  _FORWARD_BINOP(Ge)
  _FORWARD_BINOP(Gt)
  _FORWARD_BINOP(Lt)
  _FORWARD_BINOP(Le)
  _FORWARD_BINOP(Add)
  _FORWARD_BINOP(Sub)
  _FORWARD_BINOP(Mul)
  _FORWARD_BINOP(Div)
  _FORWARD_BINOP(Rem)
  _FORWARD_BINOP(Max)
  _FORWARD_BINOP(Min)
  _FORWARD_BINOP(And)
  _FORWARD_BINOP(Or)
  _FORWARD_UNOP(Not)
  _FORWARD_UNOP(Abs)
  _FORWARD_UNOP(Exp)
  _FORWARD_UNOP(Floor)
  _FORWARD_UNOP(Ceil)
  _FORWARD_UNOP(Log)
  _FORWARD_UNOP(Sign)
  _FORWARD_UNOP(Cos)
  _FORWARD_UNOP(Sin)
  _FORWARD_UNOP(Tanh)
  _FORWARD_UNOP(SqrtF32)
  _FORWARD_UNOP(SquareF32)
  _FORWARD_BINOP(Pow)
  _FORWARD_UNOP(IsFinite)
  _FORWARD_UNOP(ReciprocalF32)
  _FORWARD_UNOP(Neg)
  _FORWARD_UNOP(Sort)

#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP

 private:
  ComputationBuilder builder_;
};

static void DeleteLocalComputation(LocalComputation* computation) {
  delete computation;
}

static void DeleteCompiledLocalComputation(
    CompiledLocalComputation* computation) {
  delete computation;
}

}  // namespace swig

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
tensorflow/compiler/xla/python/local_computation_builder.i (new file, 348 lines)
@@ -0,0 +1,348 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SWIG typemaps and declarations for building, compiling, and
// executing XLA computations, wrapping most of what is declared in
// local_computation_builder.h.
//
// The typemaps below implement/assert the following correspondences
// (with elaborations below):
//
//  C++                                  Python
// -------------------------------------+---------------------------------------
//  ComputationDataHandle              <-> long
//  ArraySlice<int64>                  <-  sequence of long
//  ArraySlice<ComputationDataHandle>  <-  sequence of long
//  Literal                            <-> (nested tuple of) numpy ndarray
//  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
//  Shape                              <-> pair holding (dtype, dimensions)
//  std::vector<Shape>                 <-  sequence of shape information pairs
//  PrimitiveType                      <-  int
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally. Also,
// "long" and "int" denote the Python types so named, not C.
//
// The Python objects corresponding to C++ Literals have the type:
//
//   T = ndarray | (T, ...)
//
// where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape, an XLA primitive element type corresponding to the
// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
// to a tuple-shaped Literal whose tuple components are translated
// recursively. For example, if x is a numpy ndarray in Python, with
// shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimension 2 and 3, and XLA primitive type
// F32. Meanwhile,
//
//   (x, (x, x), (x,)),
//
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// The Python objects corresponding to C++ Shapes have the type:
//
//   T = (dtype, S)
//   S = DIMENSIONS | TUPLE_SHAPES
//   DIMENSIONS = (int, ...)
//   TUPLE_SHAPES = (T, ...)
//
// In the pair described by the T rule, the terminal dtype determines
// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
// dtype('O'), numpy's object dtype, the structure represents a tuple
// shape and the expansion of the non-terminal S is
// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
// and S expands into DIMENSIONS giving dimension sizes. For example:
//
//   (dtype('float32'), (3, 5, 7))
//
// describes a 3x5x7 array of F32s, and
//
//   (dtype('O'), ((dtype('float32'), (2, 3)),
//                 (dtype('float64'), (4, 5))))
//
// describes a tuple shape with two subshapes: the first a 2x3 F32,
// and the other a 4x5 F64.
//
// The Python int corresponding to a PrimitiveType enum must be valid
// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
//
// The SWIG object wrappers generated by this file are not intended
// for end use, but rather for internal use in the Python XLA client,
// xla_client.py.
//
// One central reason for the Python-side indirection is that the
// Python-side objects produced by the typemaps in this file are
// further packaged up by xla_client before being passed on. For
// instance, xla_client wraps the long produced for a C++
// ComputationDataHandle in a Python ComputationDataHandle proto,
// rather than exposing a raw long outside of the client. Similarly,
// the Python pair produced for a C++ Shape is further wrapped in a
// Python class (xla_client.Shape) so as not to expose the raw pair
// externally.
//
// Other SWIG object wrappers (e.g. of LocalComputation) are further
// wrapped by xla_client in order to set up a custom destructor that
// triggers memory deallocation on the C++ side.

%include "tensorflow/python/platform/base.i"

%{
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"

using namespace xla;
using namespace xla::swig;
%}

// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}

// ComputationDataHandle

%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
  const int64 handle = numpy::PyIntOrPyLongToLong($input);
  if (handle == -1 && PyErr_Occurred()) {
    return NULL;
  }
  temp.set_handle(handle);
  $1 = &temp;
}

%typemap(out) ComputationDataHandle {
  $result = numpy::LongToPyIntOrPyLong($1.handle());
}

// ArraySlice<int64>

%typemap(in) tensorflow::gtl::ArraySlice<int64>
    (std::vector<int64> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const int size = PySequence_Size($input);
  temps.resize(size);
  for (int i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      Py_DECREF(o);
      return NULL;
    }
    temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    if (temps[i] == -1 && PyErr_Occurred()) {
      Py_DECREF(py_int);
      Py_DECREF(o);
      return NULL;
    }
    Py_DECREF(py_int);
    Py_DECREF(o);
  }
  $1 = temps;
}

// ComputationDataHandle

%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
    (std::vector<ComputationDataHandle> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const int size = PySequence_Size($input);
  temps.resize(size);
  for (int i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      return NULL;
    }
    const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
    if (handle == -1 && PyErr_Occurred()) {
      Py_DECREF(py_int);
      Py_DECREF(o);
      return NULL;
    }
    temps[i].set_handle(handle);
    Py_DECREF(py_int);
    Py_DECREF(o);
  }
  $1 = temps;
}

// Literal

%typemap(in) const Literal& (std::unique_ptr<Literal> temp) {
  temp = numpy::XlaLiteralFromPyObject($input);
  $1 = &*temp;
}

%typemap(out) std::unique_ptr<Literal> {
  $result = numpy::PyObjectFromXlaLiteral(*$1);
}

%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const int size = PySequence_Size($input);
  for (int i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    temps.push_back(*numpy::XlaLiteralFromPyObject(o));
    Py_DECREF(o);
  }
  $1 = &temps;
}

// Shape

%typemap(in) const Shape& (Shape temp) {
  if (!numpy::CheckPyShapeInfo($input)) {
    return NULL;
  }
  temp = numpy::XlaShapeFromPyShapeInfo($input);
  $1 = &temp;
}

%typemap(out) std::unique_ptr<Shape> {
  $result = numpy::PyShapeInfoFromXlaShape(*$1);
}

%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const int size = PySequence_Size($input);
  for (int i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!numpy::CheckPyShapeInfo(o)) {
      Py_DECREF(o);
      return NULL;
    }
    temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
    Py_DECREF(o);
  }
  $1 = &temps;
}

// PrimitiveType

%typemap(in) PrimitiveType {
  PyObject* py_int = numpy::PyNumberToPyInt($input);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    return NULL;
  }
  const long value = numpy::PyIntOrPyLongToLong(py_int);
  if (value == -1 && PyErr_Occurred()) {
    Py_DECREF(py_int);
    return NULL;
  }
  if (!PrimitiveType_IsValid(value)) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    Py_DECREF(py_int);
    return NULL;
  }
  $1 = static_cast<PrimitiveType>(value);
}

%ignoreall
%unignore xla;
%unignore xla::swig;
%unignore xla::swig::CompiledLocalComputation;
%unignore xla::swig::CompiledLocalComputation::Execute;
%unignore xla::swig::LocalComputation;
%unignore xla::swig::LocalComputation::Compile;
%unignore xla::swig::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::Build;
%unignore xla::swig::LocalComputationBuilder::Parameter;
%unignore xla::swig::LocalComputationBuilder::GetShape;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
%unignore xla::swig::LocalComputationBuilder::ConstantR0;
%unignore xla::swig::LocalComputationBuilder::Broadcast;
%unignore xla::swig::LocalComputationBuilder::Reshape;
%unignore xla::swig::LocalComputationBuilder::Slice;
%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
%unignore xla::swig::LocalComputationBuilder::Select;
%unignore xla::swig::LocalComputationBuilder::Tuple;
%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::While;
%unignore xla::swig::LocalComputationBuilder::Eq;
%unignore xla::swig::LocalComputationBuilder::Ne;
%unignore xla::swig::LocalComputationBuilder::Ge;
%unignore xla::swig::LocalComputationBuilder::Gt;
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul;
%unignore xla::swig::LocalComputationBuilder::Div;
%unignore xla::swig::LocalComputationBuilder::Rem;
%unignore xla::swig::LocalComputationBuilder::Max;
%unignore xla::swig::LocalComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
%unignore xla::swig::LocalComputationBuilder::SqrtF32;
%unignore xla::swig::LocalComputationBuilder::SquareF32;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DeleteCompiledLocalComputation;

%include "tensorflow/compiler/xla/python/local_computation_builder.h"

%unignoreall
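To make the encodings described in the typemap comments above concrete, here is a small illustrative Python sketch; the values mirror the examples given in those comments and are not part of the commit:

    import numpy as np

    # A non-tuple shape record: (element dtype, tuple of dimension sizes).
    f32_3x5x7 = (np.dtype('float32'), (3, 5, 7))

    # A tuple shape record: dtype('O') plus a tuple of component shape records.
    tuple_shape = (np.dtype('O'), ((np.dtype('float32'), (2, 3)),
                                   (np.dtype('float64'), (4, 5))))

    # Literals follow T = ndarray | (T, ...): a 2x3 float32 ndarray becomes a
    # rank-2 F32 literal, and a Python tuple of such values becomes a
    # tuple-shaped literal whose components are converted recursively.
    x = np.zeros((2, 3), dtype=np.float32)
    nested_literal = (x, (x, x), (x,))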
tensorflow/compiler/xla/python/numpy_bridge.cc (new file, 389 lines)
@@ -0,0 +1,389 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace swig {

namespace numpy {

int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return NPY_BOOL;
    case S8:
      return NPY_INT8;
    case S16:
      return NPY_INT16;
    case S32:
      return NPY_INT32;
    case S64:
      return NPY_INT64;
    case U8:
      return NPY_UINT8;
    case U16:
      return NPY_UINT16;
    case U32:
      return NPY_UINT32;
    case U64:
      return NPY_UINT64;
    case F16:
      return NPY_FLOAT16;
    case F32:
      return NPY_FLOAT32;
    case F64:
      return NPY_FLOAT64;
    case TUPLE:
      return NPY_OBJECT;
    default:
      LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
  }
}

PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
  switch (np_type) {
    case NPY_BOOL:
      return PRED;
    case NPY_INT8:
      return S8;
    case NPY_INT16:
      return S16;
    case NPY_INT32:
      return S32;
    case NPY_INT64:
      return S64;
    case NPY_UINT8:
      return U8;
    case NPY_UINT16:
      return U16;
    case NPY_UINT32:
      return U32;
    case NPY_UINT64:
      return U64;
    case NPY_FLOAT16:
      return F16;
    case NPY_FLOAT32:
      return F32;
    case NPY_FLOAT64:
      return F64;
    case NPY_OBJECT:
      return TUPLE;
    default:
      LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
  }
}

bool NumpyTypeIsValid(int np_type) {
  switch (np_type) {
    case NPY_BOOL:
    case NPY_INT8:
    case NPY_INT16:
    case NPY_INT32:
    case NPY_INT64:
    case NPY_UINT8:
    case NPY_UINT16:
    case NPY_UINT32:
    case NPY_UINT64:
    case NPY_FLOAT16:
    case NPY_FLOAT32:
    case NPY_FLOAT64:
    case NPY_OBJECT:
      return true;
    default:
      return false;
  }
}

PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
  int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
  PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);

  PyObject* dimensions;
  if (ShapeUtil::IsTuple(shape)) {
    int num_elements = ShapeUtil::TupleElementCount(shape);
    dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape));
    for (int i = 0; i < num_elements; ++i) {
      PyTuple_SET_ITEM(
          dimensions, i,
          PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
    }
  } else {
    int rank = ShapeUtil::Rank(shape);
    dimensions = PyTuple_New(rank);
    for (int i = 0; i < rank; ++i) {
      PyTuple_SET_ITEM(dimensions, i,
                       LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
    }
  }
  return PyTuple_Pack(2, np_dtype, dimensions);
}

// Precondition: o->ob_type == &PyArrayDescr_Type
static int NumpyTypenum(PyObject* o) {
  return reinterpret_cast<PyArray_Descr*>(o)->type_num;
}

bool CheckPyShapeInfo(PyObject* o) {
  // The object is a tuple (a pair)
  if (!PyTuple_Check(o)) {
    PyErr_SetString(PyExc_TypeError, "Shape record must be a tuple");
    return false;
  }
  if (PyTuple_Size(o) != 2) {
    PyErr_SetString(PyExc_ValueError, "Shape record tuple must be of length 2");
    return false;
  }

  // It has a first element, which is a numpy dtype object
  PyObject* first = PyTuple_GetItem(o, 0);
  if (!first) {
    return false;
  }
  if (first->ob_type != &PyArrayDescr_Type) {
    PyErr_SetString(
        PyExc_TypeError,
        "Shape record does not have a numpy dtype as its first element");
    return false;
  }
  const int np_type = NumpyTypenum(first);
  if (!NumpyTypeIsValid(np_type)) {
    PyErr_SetString(PyExc_ValueError,
                    "Shape record has an invalid integer dtype");
    return false;
  }

  // It has a second element, which is a tuple, either of shape
  // records or of Python ints
  PyObject* second = PyTuple_GetItem(o, 1);
  if (!second) {
    return false;
  }
  if (!PyTuple_Check(second)) {
    PyErr_SetString(PyExc_TypeError,
                    "Shape record does not have a tuple as its second element");
    return false;
  }
  const int length = PyTuple_Size(second);
  const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
  for (int i = 0; i < length; i++) {
    PyObject* dimension = PyTuple_GetItem(second, i);
    if (element_type == TUPLE) {
      if (!CheckPyShapeInfo(dimension)) {
        return false;
      }
    } else if (!CheckPyIntOrLong(dimension)) {
      PyErr_SetString(PyExc_TypeError,
                      "Non-tuple shape record has a non-integer dimension");
      return false;
    }
  }

  return true;
}

// Precondition: CheckPyShapeInfo(o)
Shape XlaShapeFromPyShapeInfo(PyObject* o) {
  const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
  const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
  PyObject* py_dimensions = PyTuple_GetItem(o, 1);
  const int length = PyTuple_Size(py_dimensions);
  if (element_type == TUPLE) {
    std::vector<Shape> subshapes;
    subshapes.reserve(length);
    for (int i = 0; i < length; i++) {
      subshapes.push_back(
          XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  } else {
    std::vector<int64> dimensions(length);
    for (int i = 0; i < length; i++) {
      dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
      if (dimensions[i] == -1) {
        CHECK(!PyErr_Occurred());
      }
    }
    return ShapeUtil::MakeShape(element_type, dimensions);
  }
}

PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
  if (ShapeUtil::IsTuple(literal.shape())) {
    const std::vector<Literal>& tuple_literals = literal.tuple_literals();
    int num_elements = ShapeUtil::TupleElementCount(literal.shape());
    PyObject* tuple = PyTuple_New(num_elements);
    for (int i = 0; i < num_elements; i++) {
      PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i]));
    }
    return tuple;
  } else {
    int rank = ShapeUtil::Rank(literal.shape());
    std::vector<long> dimensions(rank);  // NOLINT - PyArray requires a long*
    for (int i = 0; i < rank; i++) {
      dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
    }
    int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
    PyObject* array =
        PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
    CopyLiteralToNumpyArray(np_type, literal,
                            reinterpret_cast<PyArrayObject*>(array));
    return array;
  }
}

std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o) {
  if (PyTuple_Check(o)) {
    int num_elements = PyTuple_Size(o);
    std::vector<std::unique_ptr<Literal>> elements;
    elements.reserve(num_elements);
    for (int i = 0; i < num_elements; i++) {
      PyObject* element = PyTuple_GetItem(o, i);
      elements.push_back(XlaLiteralFromPyObject(element));
    }
    return Literal::MakeTupleOwned(std::move(elements));
  } else if (PyArray_Check(o)) {
    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
    int rank = PyArray_NDIM(py_array);
    std::vector<int64> dimensions(rank);
    for (int i = 0; i < rank; i++) {
      dimensions[i] = PyArray_DIM(py_array, i);
    }
    int np_type = PyArray_TYPE(py_array);
    auto literal = Literal::CreateFromDimensions(
        NumpyTypeToPrimitiveType(np_type), dimensions);
    CopyNumpyArrayToLiteral(np_type, py_array, literal.get());
    return literal;
  } else {
    LOG(FATAL)
        << "Non-tuple or Numpy array encountered in conversion to XLA literal";
  }
}

void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
                             Literal* literal) {
  switch (np_type) {
    case NPY_BOOL:
      CopyNumpyArrayToLiteral<bool>(py_array, literal);
      break;
    case NPY_INT32:
      CopyNumpyArrayToLiteral<int32>(py_array, literal);
      break;
    case NPY_INT64:
      CopyNumpyArrayToLiteral<int64>(py_array, literal);
      break;
    case NPY_UINT8:
      CopyNumpyArrayToLiteral<uint8>(py_array, literal);
      break;
    case NPY_UINT32:
      CopyNumpyArrayToLiteral<uint32>(py_array, literal);
      break;
    case NPY_UINT64:
      CopyNumpyArrayToLiteral<uint64>(py_array, literal);
      break;
    case NPY_FLOAT16:
      CopyNumpyArrayToLiteral<half>(py_array, literal);
      break;
    case NPY_FLOAT32:
      CopyNumpyArrayToLiteral<float>(py_array, literal);
      break;
    case NPY_FLOAT64:
      CopyNumpyArrayToLiteral<double>(py_array, literal);
      break;
    default:
      LOG(FATAL) << "No XLA literal container for Numpy type " << np_type;
  }
}

void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
                             PyArrayObject* py_array) {
  switch (np_type) {
    case NPY_BOOL:
      CopyLiteralToNumpyArray<bool>(literal, py_array);
      break;
    case NPY_INT32:
      CopyLiteralToNumpyArray<int32>(literal, py_array);
      break;
    case NPY_INT64:
      CopyLiteralToNumpyArray<int64>(literal, py_array);
      break;
    case NPY_UINT8:
      CopyLiteralToNumpyArray<uint8>(literal, py_array);
      break;
    case NPY_UINT32:
      CopyLiteralToNumpyArray<uint32>(literal, py_array);
      break;
    case NPY_UINT64:
      CopyLiteralToNumpyArray<uint64>(literal, py_array);
      break;
    case NPY_FLOAT16:
      CopyLiteralToNumpyArray<half>(literal, py_array);
      break;
    case NPY_FLOAT32:
      CopyLiteralToNumpyArray<float>(literal, py_array);
      break;
    case NPY_FLOAT64:
      CopyLiteralToNumpyArray<double>(literal, py_array);
      break;
    default:
      LOG(FATAL) << "No XLA literal container for Numpy type " << np_type;
  }
}

PyObject* LongToPyIntOrPyLong(long x) {  // NOLINT
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(x);
#else
  return PyLong_FromLong(x);
#endif
}

long PyIntOrPyLongToLong(PyObject* o) {  // NOLINT
#if PY_MAJOR_VERSION < 3
  return PyInt_AsLong(o);
#else
  return PyLong_AsLong(o);
#endif
}

bool CheckPyIntOrLong(PyObject* o) {
#if PY_MAJOR_VERSION < 3
  return PyInt_Check(o);
#else
  if (!PyLong_Check(o)) {
    return false;
  }
  int overflow = 0;
  PyLong_AsLongAndOverflow(o, &overflow);
  return (overflow == 0);
#endif
}

PyObject* PyNumberToPyInt(PyObject* o) {
#if PY_MAJOR_VERSION < 3
  return PyNumber_Int(o);
#else
  return PyNumber_Long(o);
#endif
}

}  // namespace numpy

}  // namespace swig

}  // namespace xla
tensorflow/compiler/xla/python/numpy_bridge.h (new file, 123 lines)
@@ -0,0 +1,123 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// These functions transform Python/Numpy data structures to XLA data
// structures and vice versa, performing copies where
// appropriate. Python tuples and Numpy ndarrays translate to XLA
// tuples and XLA literals, respectively, and Numpy shape/dtype
// information is translated to XLA shape information.

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_

#include <algorithm>
#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"

namespace xla {

namespace swig {

namespace numpy {

// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa.
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type);
PrimitiveType NumpyTypeToPrimitiveType(int np_type);

// Determines whether an integer-encoded Numpy dtype is valid,
// i.e. has a supported conversion to an XLA PrimitiveType.
bool NumpyTypeIsValid(int np_type);

// Converts XLA shape information into a Python pair of the form
// (numpy dtype, dimensions). If the XLA shape represents a tuple,
// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a
// Python tuple of shape-description pairs, created
// recursively. Otherwise, `dimensions` is a Python tuple-of-integers
// providing the array dimensions.
//
// The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape);

// Returns the outcome of a best-effort check that the Python object
// is a pair of the form (numpy dtype, dimensions), as produced by
// PyShapeInfoFromXlaShape.
bool CheckPyShapeInfo(PyObject* o);

// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
//
// The return value is a new reference.
Shape XlaShapeFromPyShapeInfo(PyObject* o);

// Converts an XLA literal to a Python object, either a Numpy ndarray
// or a nested Python tuple thereof.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
//
// The return value is a new reference.
PyObject* PyObjectFromXlaLiteral(const Literal& literal);

// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o);

// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.

void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
                             Literal* literal);

void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
                             PyArrayObject* py_array);

template <typename NativeT>
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
  NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
  auto dest = literal->GetMutableArraySlice<NativeT>();
  std::copy(source, source + PyArray_SIZE(py_array), dest.data());
}

template <typename NativeT>
void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
  NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
  auto source = literal.GetArraySlice<NativeT>();
  std::copy(source.begin(), source.end(), dest);
}

// Workarounds for Python 2 and 3 interop

PyObject* LongToPyIntOrPyLong(long x);  // NOLINT
long PyIntOrPyLongToLong(PyObject* o);  // NOLINT
bool CheckPyIntOrLong(PyObject* o);
PyObject* PyNumberToPyInt(PyObject* o);

}  // namespace numpy

}  // namespace swig

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
tensorflow/compiler/xla/python/xla.i (new file, 18 lines)
@@ -0,0 +1,18 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* XLA-wide SWIG wrapper */

%include "tensorflow/compiler/xla/python/local_computation_builder.i"
tensorflow/compiler/xla/python/xla_client.py (new file, 605 lines; listing truncated below)
@@ -0,0 +1,605 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An in-process, local XLA client in Python, supporting AOT compilation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.xla import xla_data_pb2
|
||||
from tensorflow.compiler.xla.python import pywrap_xla as c_api
|
||||
|
||||
_UNARY_OPS = [
|
||||
'Not',
|
||||
'Abs',
|
||||
'Exp',
|
||||
'Floor',
|
||||
'Ceil',
|
||||
'Log',
|
||||
'Sign',
|
||||
'Cos',
|
||||
'Sin',
|
||||
'Tanh',
|
||||
'SqrtF32',
|
||||
'SquareF32',
|
||||
'IsFinite',
|
||||
'ReciprocalF32',
|
||||
'Neg',
|
||||
'Sort',
|
||||
]
|
||||
|
||||
_BINARY_OPS = [
|
||||
'Eq',
|
||||
'Ne',
|
||||
'Ge',
|
||||
'Gt',
|
||||
'Lt',
|
||||
'Le',
|
||||
'Add',
|
||||
'Sub',
|
||||
'Mul',
|
||||
'Div',
|
||||
'Rem',
|
||||
'Max',
|
||||
'Min',
|
||||
'And',
|
||||
'Or',
|
||||
'Pow',
|
||||
]
|
||||
|
||||
# Most functions are snake_case for consistency with other modules,
|
||||
# whereas method names of ComputationBuilder and LocalComputation are
|
||||
# CamelCase for consistency with XLA.
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
XLA_ELEMENT_TYPE_TO_DTYPE = {
|
||||
xla_data_pb2.F32: np.dtype(np.float32),
|
||||
xla_data_pb2.F64: np.dtype(np.float64),
|
||||
xla_data_pb2.S32: np.dtype(np.int32),
|
||||
xla_data_pb2.S64: np.dtype(np.int64),
|
||||
xla_data_pb2.PRED: np.dtype(np.bool),
|
||||
xla_data_pb2.TUPLE: np.dtype(np.object),
|
||||
}
|
||||
|
||||
DTYPE_TO_XLA_ELEMENT_TYPE = {
|
||||
str(v): k
|
||||
for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items()
|
||||
}
|
||||
|
||||
|
||||
class Shape(object):
|
||||
"""XLA shape.
|
||||
|
||||
Represents an XLA shape by a corresponding Python/Numpy type and a
|
||||
list of dimensions, which are themselves Shapes in case this one
|
||||
represents an XLA tuple.
|
||||
"""
|
||||
|
||||
def __init__(self, np_dtype, dimensions):
|
||||
self.np_dtype = np_dtype
|
||||
self._dimensions = dimensions
|
||||
|
||||
def element_type(self):
|
||||
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
|
||||
|
||||
def is_tuple(self):
|
||||
return self.element_type() == xla_data_pb2.TUPLE
|
||||
|
||||
def dimensions(self):
|
||||
if self.is_tuple():
|
||||
raise ValueError('Tuple shape has no dimensions')
|
||||
return self._dimensions
|
||||
|
||||
def tuple_shapes(self):
|
||||
if not self.is_tuple():
|
||||
raise ValueError('Shape is not a tuple shape')
|
||||
return self._dimensions
|
||||
|
||||
@staticmethod
|
||||
def from_numpy(npval):
|
||||
|
||||
def convert(npval):
|
||||
if isinstance(npval, tuple):
|
||||
return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval))
|
||||
else:
|
||||
return Shape(npval.dtype, np.shape(npval))
|
||||
|
||||
return convert(require_numpy_array_layout(npval))
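# Illustrative sketch, not part of the original module: how Shape.from_numpy
# maps ndarrays and nested tuples onto Shape objects. The helper name
# _example_shape_from_numpy is hypothetical.
def _example_shape_from_numpy():
  array_shape = Shape.from_numpy(np.zeros((2, 3), dtype=np.float32))
  assert not array_shape.is_tuple()
  assert array_shape.dimensions() == (2, 3)

  tuple_shape = Shape.from_numpy((np.float32(0), np.zeros(4, np.int32)))
  assert tuple_shape.is_tuple()
  assert len(tuple_shape.tuple_shapes()) == 2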
|
||||
|
||||
|
||||
def _wrap_shape(shape_info):
|
||||
dtype, dims = shape_info
|
||||
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
|
||||
if element_type == xla_data_pb2.TUPLE:
|
||||
dims = [_wrap_shape(subshape_info) for subshape_info in dims]
|
||||
return Shape(dtype, dims)
|
||||
|
||||
|
||||
def _unwrap_shape(shape):
|
||||
if shape.is_tuple():
|
||||
components = tuple(
|
||||
_unwrap_shape(subshape) for subshape in shape.tuple_shapes())
|
||||
else:
|
||||
components = shape.dimensions()
|
||||
return (shape.np_dtype, components)
|
||||
|
||||
|
||||
def _unwrap_shapes(shapes):
|
||||
return [_unwrap_shape(shape) for shape in shapes]
|
||||
|
||||
|
||||
def _wrap_data_handle(handle):
|
||||
cdh = xla_data_pb2.ComputationDataHandle()
|
||||
cdh.handle = handle
|
||||
return cdh
|
||||
|
||||
|
||||
def _unwrap_data_handle(handle_proto):
|
||||
return handle_proto.handle
|
||||
|
||||
|
||||
def _unwrap_data_handles(handle_protos):
|
||||
return [_unwrap_data_handle(cdh) for cdh in handle_protos]
|
||||
|
||||
|
||||
def require_numpy_array_layout(value):
|
||||
if isinstance(value, tuple):
|
||||
return tuple(require_numpy_array_layout(x) for x in value)
|
||||
else:
|
||||
return np.require(value, requirements=['C', 'A'])
|
||||
|
||||
|
||||
class LocalComputation(object):
|
||||
"""Python wrapper for a local XLA Computation.
|
||||
|
||||
A LocalComputation can be executed if it is compiled. Otherwise, it
|
||||
can still be used as a Computation where required by the
|
||||
ComputationBuilder methods.
|
||||
"""
|
||||
|
||||
def __init__(self, c_local_computation, is_compiled):
|
||||
self.c_local_computation = c_local_computation
|
||||
self.is_compiled = is_compiled
|
||||
|
||||
# Ensure a reference to C-based destructor for use in __del__.
|
||||
if is_compiled:
|
||||
self._delete = c_api.DeleteCompiledLocalComputation
|
||||
else:
|
||||
self._delete = c_api.DeleteLocalComputation
|
||||
|
||||
def Compile(self, argument_shapes=()):
|
||||
if self.is_compiled:
|
||||
raise ValueError('Attempt to compile a compiled local XLA computation.')
|
||||
return LocalComputation(
|
||||
self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
|
||||
is_compiled=True)
|
||||
|
||||
def CompileWithExampleArguments(self, arguments=()):
|
||||
return self.Compile(
|
||||
argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
|
||||
|
||||
def Execute(self, arguments=()):
|
||||
if not self.is_compiled:
|
||||
raise ValueError('Cannot execute an uncompiled local XLA computation.')
|
||||
arguments = tuple(map(require_numpy_array_layout, arguments))
|
||||
return self.c_local_computation.Execute(arguments)
|
||||
|
||||
def __del__(self):
|
||||
self._delete(self.c_local_computation)
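# Illustrative sketch, not part of the original module: the typical lifecycle
# of a LocalComputation built with ComputationBuilder (defined below). The
# function name _example_compile_and_run is hypothetical.
def _example_compile_and_run():
  builder = ComputationBuilder('add_one')
  x = builder.ParameterFromNumpy(np.float32(41.0))
  builder.Add(x, builder.ConstantF32Scalar(1.0))
  uncompiled = builder.Build()  # A LocalComputation with is_compiled=False.
  compiled = uncompiled.CompileWithExampleArguments([np.float32(41.0)])
  return compiled.Execute([np.float32(41.0)])  # Runs via the local client.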
|
||||
|
||||
|
||||
class ComputationBuilder(object):
|
||||
"""XLA computation builder.
|
||||
|
||||
Enqueues XLA ops in sequence and in order to build a
|
||||
LocalComputation, which in turn can be compiled into a
|
||||
CompiledLocalComputation, which in turn can be locally executed.
|
||||
"""
|
||||
|
||||
# The methods of this class map 1-to-1 onto the XLA C++
|
||||
# computation builder API. Therefore, there's no need to laboriously list
|
||||
# arguments and return values for every method, especially where it's obvious.
|
||||
#
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
# pylint: disable=g-doc-args
|
||||
|
||||
def __init__(self, name):
|
||||
self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
|
||||
self._parameter_numbering = itertools.count()
|
||||
|
||||
def Build(self):
|
||||
return LocalComputation(self._client.Build(), is_compiled=False)
|
||||
|
||||
def Constant(self, value):
|
||||
"""Enqueues a constant op onto the computation.
|
||||
|
||||
Args:
|
||||
value: value for the constant, as a np.array with an explicit dtype set
|
||||
to one of the supported types.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
value = require_numpy_array_layout(value)
|
||||
return _wrap_data_handle(self._client.ConstantLiteral(value))
|
||||
|
||||
def ConstantF32Scalar(self, value):
|
||||
"""Convenience method to enqueue a scalar F32 constant op.
|
||||
|
||||
Args:
|
||||
value: a floating-point number.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.float32))
|
||||
|
||||
def ConstantF64Scalar(self, value):
|
||||
"""Convenience method to enqueue a scalar F32 constant op.
|
||||
|
||||
Args:
|
||||
value: a floating-point number.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.float64))
|
||||
|
||||
def ConstantS32Scalar(self, value):
|
||||
"""Convenience method to enqueue a scalar S32 constant op.
|
||||
|
||||
Args:
|
||||
value: an integer.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.int32))
|
||||
|
||||
def ConstantS64Scalar(self, value):
|
||||
"""Convenience method to enqueue a scalar S64 constant op.
|
||||
|
||||
Args:
|
||||
value: an integer.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.int64))
|
||||
|
||||
def ConstantPredScalar(self, value):
|
||||
"""Convenience method to enqueue a scalar PRED constant op.
|
||||
|
||||
Args:
|
||||
value: a boolean value.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.bool))
|
||||
|
||||
def ParameterWithShape(self, shape, name=None, parameter_num=None):
|
||||
"""Enqueues a Parameter op onto the computation, given a shape.
|
||||
|
||||
Args:
|
||||
shape: the parameter's shape as a Shape object.
|
||||
name: optional string name for the parameter.
|
||||
parameter_num: parameter number in the computation function. If None,
|
||||
the next linear parameter number is used. The default value capability
|
||||
can be used for auto-numbering. If you're using auto-numbering for some
|
||||
parameters, use it for *all* parameters to avoid clashes.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
if name is None:
|
||||
name = ''
|
||||
if parameter_num is None:
|
||||
parameter_num = next(self._parameter_numbering)
|
||||
|
||||
return _wrap_data_handle(
|
||||
self._client.Parameter(
|
||||
parameter_num, _unwrap_shape(shape), name.encode('utf8')))
|
||||
|
||||
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
|
||||
"""Enqueues a Parameter op onto the computation.
|
||||
|
||||
Args:
|
||||
value: a Numpy array, or a nested tuple thereof, from which the
|
||||
shape is inferred.
|
||||
name: as in ParameterWithShape.
|
||||
parameter_num: as in ParameterWithShape.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle message.
|
||||
"""
|
||||
return self.ParameterWithShape(
|
||||
Shape.from_numpy(value), name=name, parameter_num=parameter_num)
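# Illustrative sketch, not part of the original module: auto-numbered
# parameters receive numbers 0, 1, ... in declaration order; alternatively,
# pass parameter_num explicitly, but then do so for every parameter to avoid
# clashes with the auto-numbering counter. The helper name is hypothetical.
def _example_parameter_numbering(builder):
  p0 = builder.ParameterFromNumpy(np.float32(0))            # parameter 0
  p1 = builder.ParameterFromNumpy(np.zeros(4, np.float32))  # parameter 1
  return builder.Mul(p0, p1)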
|
||||
|
||||
def Broadcast(self, operand, sizes):
|
||||
"""Enqueues a broadcast operation onto the computation.
|
||||
|
||||
Args:
|
||||
operand: the operand ComputationDataHandle to broadcast.
|
||||
sizes: an iterable of broadcast sizes.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added broadcast op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.Broadcast(_unwrap_data_handle(operand), sizes))
|
||||
|
||||
def Concatenate(self, operands, dimension):
|
||||
"""Enqueues a concatenate operation onto the computation.
|
||||
|
||||
Args:
|
||||
operands: the operands to concatenate.
|
||||
dimension: the dimension in which to perform the concatenation.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added concatenate op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.ConcatInDim(_unwrap_data_handles(operands), dimension))
|
||||
|
||||
def ConvertElementType(self, operand, new_element_type):
|
||||
"""Enqueues an element type conversion operation onto the computation.
|
||||
|
||||
Args:
|
||||
operand: the operand to convert.
|
||||
new_element_type: the target primitive type.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added conversion op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.ConvertElementType(
|
||||
_unwrap_data_handle(operand), new_element_type))
|
||||
|
||||
def GetShape(self, operand):
|
||||
return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
|
||||
|
||||
def GetComputationStats(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def Reshape(self, operand, dimensions, new_sizes):
|
||||
"""Reshape op."""
|
||||
return _wrap_data_handle(
|
||||
self._client.Reshape(
|
||||
_unwrap_data_handle(operand), dimensions, new_sizes))
|
||||
|
||||
def Trans(self, operand):
|
||||
"""Specialized matrix transpose op."""
|
||||
return _wrap_data_handle(
|
||||
self._client.Transpose(_unwrap_data_handle(operand), [1, 0]))
|
||||
|
||||
def Transpose(self, operand, permutation):
|
||||
"""Transpose op."""
|
||||
return _wrap_data_handle(
|
||||
self._client.Transpose(_unwrap_data_handle(operand), permutation))
|
||||
|
||||
def Select(self, pred, on_true, on_false):
|
||||
"""Element-wise selection op.
|
||||
|
||||
Constructs an output array from elements of two input arrays, based on the
|
||||
values of a predicate array.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.Select(
|
||||
_unwrap_data_handle(pred),
|
||||
_unwrap_data_handle(on_true),
|
||||
_unwrap_data_handle(on_false)))
|
||||
|
||||
def Slice(self, operand, start_indices, limit_indices, strides=None):
|
||||
"""Enqueues a slice operation onto the computation.
|
||||
|
||||
Args:
|
||||
operand: ComputationDataHandle for the N dimensional array to be sliced.
|
||||
start_indices: iterable of N integers containing the starting indices of
|
||||
the slice for each dimension.
|
||||
limit_indices: iterable of N integers containing the ending indices
|
||||
(exclusive) of the slice for each dimension.
|
||||
strides: optional iterable of N integers containing the stride sizes for
|
||||
each dimension.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added Slice op.
|
||||
"""
|
||||
if strides is None:
|
||||
start_indices = list(start_indices)
|
||||
strides = [1] * len(start_indices)
|
||||
return _wrap_data_handle(
|
||||
self._client.Slice(
|
||||
_unwrap_data_handle(operand),
|
||||
start_indices,
|
||||
limit_indices,
|
||||
strides))
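# Illustrative sketch, not part of the original module: slicing rows 1..2 and
# columns 0..1 of a 3x3 constant. Omitting strides is equivalent to passing a
# stride of 1 for every dimension. The helper name is hypothetical.
def _example_slice(builder):
  operand = builder.Constant(np.arange(9, dtype=np.int32).reshape(3, 3))
  return builder.Slice(operand, [1, 0], [3, 2])  # Same as strides=[1, 1].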
|
||||
|
||||
def DynamicSlice(self, operand, start_indices, slice_sizes):
|
||||
"""Enqueues a slice op with dynamic start indices onto the computation.
|
||||
|
||||
Args:
|
||||
operand: ComputationDataHandle for the N dimensional array to be sliced.
|
||||
start_indices: ComputationDataHandle for the 1D array of N integers
|
||||
containing the starting indices of the slice.
|
||||
slice_sizes: iterable of N integers containing the slice sizes in each
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added DynamicSlice op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.DynamicSlice(
|
||||
_unwrap_data_handle(operand),
|
||||
_unwrap_data_handle(start_indices),
|
||||
slice_sizes))
|
||||
|
||||
def DynamicUpdateSlice(self, operand, update, start_indices):
|
||||
"""Enqueues a dynamic update slice operation onto the computation.
|
||||
|
||||
Args:
|
||||
operand: ComputationDataHandle for the N dimensional array to be updated.
|
||||
update: N dimensional array comprising the slice update.
|
||||
start_indices: Rank-1 array of N integers comprising the starting indices
|
||||
of the slice along each dimension.
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added DynamicUpdateSlice op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.DynamicUpdateSlice(
|
||||
_unwrap_data_handle(operand),
|
||||
_unwrap_data_handle(update),
|
||||
_unwrap_data_handle(start_indices)))
|
||||
|
||||
def Tuple(self, *ops):
|
||||
"""Enqueues a tuple operation onto the computation.
|
||||
|
||||
Args:
|
||||
ops: a sequence of tuple operands (each a ComputationDataHandle).
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added Tuple op.
|
||||
"""
|
||||
return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops)))
|
||||
|
||||
def GetTupleElement(self, tup, index):
|
||||
"""Enqueues a 'get tuple element' operation onto the computation.
|
||||
|
||||
Args:
|
||||
tup: the tuple operand (a ComputationDataHandle).
|
||||
index: numeric index to select from the tuple.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added GetTupleElement op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.GetTupleElement(_unwrap_data_handle(tup), index))
|
||||
|
||||
def Call(self, computation_to_apply, operands):
|
||||
"""Enqueues a call operation onto the computation.
|
||||
|
||||
Args:
|
||||
computation_to_apply: a Computation object.
|
||||
operands: an iterable of ComputationDataHandle. The number and types of
|
||||
operands must match the arity of computation_to_apply.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added call op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.Call(computation_to_apply.c_local_computation,
|
||||
_unwrap_data_handles(operands)))
|
||||
|
||||
def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
|
||||
"""Enqueues a map operation onto the computation.
|
||||
|
||||
Args:
|
||||
operands: an iterable of ComputationDataHandle.
|
||||
computation_to_apply: a Computation object.
|
||||
dimensions: dimensions over which to map the function.
|
||||
static_operands: auxiliary arguments passed to the applied computation.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added Map op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.Map(
|
||||
_unwrap_data_handles(operands),
|
||||
computation_to_apply.c_local_computation,
|
||||
dimensions,
|
||||
_unwrap_data_handles(static_operands)))
|
||||
|
||||
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
|
||||
"""Enqueues a reduction operation onto the computation.
|
||||
|
||||
Args:
|
||||
operand: reduction operand (ComputationDataHandle).
|
||||
init_value: reduction initial value (ComputationDataHandle).
|
||||
computation_to_apply: a Computation object - binary reduction function.
|
||||
dimensions: sequence of dimensions (integers) to reduce on.
|
||||
|
||||
Returns:
|
||||
A ComputationDataHandle representing the added Reduce op.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.Reduce(
|
||||
_unwrap_data_handle(operand),
|
||||
_unwrap_data_handle(init_value),
|
||||
computation_to_apply.c_local_computation,
|
||||
dimensions))
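# Illustrative sketch, not part of the original module: reducing a vector to a
# scalar sum. The binary reduction function is itself a LocalComputation built
# with a second ComputationBuilder. The helper name is hypothetical.
def _example_reduce_sum(builder):
  add = ComputationBuilder('add_f32')
  add.Add(add.ParameterFromNumpy(np.float32(0)),
          add.ParameterFromNumpy(np.float32(0)))
  return builder.Reduce(
      operand=builder.Constant(np.array([1.0, 2.0, 3.0], dtype=np.float32)),
      init_value=builder.ConstantF32Scalar(0),
      computation_to_apply=add.Build(),
      dimensions=[0])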
|
||||
|
||||
def While(self, cond, body, init):
|
||||
"""Enqueues a While operation onto the computation.
|
||||
|
||||
Args:
|
||||
cond: a Computation for the loop condition, which has type T -> PRED
|
||||
body: a Computation for the loop body, which has type T -> T
|
||||
init: a ComputationDataHandle for the initial parameter, which has type T
|
||||
|
||||
Returns:
  A ComputationDataHandle representing the While operation.
|
||||
"""
|
||||
return _wrap_data_handle(
|
||||
self._client.While(cond.c_local_computation,
|
||||
body.c_local_computation,
|
||||
_unwrap_data_handle(init)))
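# Illustrative sketch, not part of the original module: a loop that doubles a
# scalar until it reaches 10, matching the cond (T -> PRED) and body (T -> T)
# signatures described above. The helper name is hypothetical.
def _example_while(builder):
  cond = ComputationBuilder('lt_ten')
  cond.Lt(cond.ParameterFromNumpy(np.float32(0)), cond.ConstantF32Scalar(10.0))

  body = ComputationBuilder('double')
  body.Mul(body.ParameterFromNumpy(np.float32(0)), body.ConstantF32Scalar(2.0))

  return builder.While(cond.Build(), body.Build(),
                       builder.ConstantF32Scalar(1.0))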
|
||||
|
||||
def Dot(self, lhs, rhs):
|
||||
"""Matrix multiplication between lhs and rhs."""
|
||||
return _wrap_data_handle(
|
||||
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
|
||||
|
||||
|
||||
def _forward_methods_to_local_builder():
|
||||
"""Forward remaining ComputationBuilder methods to the C API.
|
||||
|
||||
Sets up methods corresponding to unary and binary XLA operations
|
||||
whose calls are forwarded in a boilerplate manner to the underlying
|
||||
LocalComputationBuilder C-extension API.
|
||||
"""
|
||||
|
||||
def forward_to_local_builder_with_handles(target_method, is_binop=False):
|
||||
"""Generate a forwarding method that wraps/unwraps data handles."""
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
unwrapped_args = [_unwrap_data_handle(arg) for arg in args]
|
||||
|
||||
if is_binop and len(unwrapped_args) < 3:
|
||||
unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))
|
||||
|
||||
return _wrap_data_handle(
|
||||
target_method(
|
||||
self._client, # pylint: disable=protected-access
|
||||
*unwrapped_args))
|
||||
|
||||
return forward
|
||||
|
||||
for method_name in _UNARY_OPS:
|
||||
forward = forward_to_local_builder_with_handles(
|
||||
getattr(c_api.LocalComputationBuilder, method_name))
|
||||
forward.__name__ = method_name
|
||||
setattr(ComputationBuilder, method_name, forward)
|
||||
|
||||
for method_name in _BINARY_OPS:
|
||||
forward = forward_to_local_builder_with_handles(
|
||||
getattr(c_api.LocalComputationBuilder, method_name), is_binop=True)
|
||||
forward.__name__ = method_name
|
||||
setattr(ComputationBuilder, method_name, forward)
|
||||
|
||||
|
||||
_forward_methods_to_local_builder()
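# Illustrative sketch, not part of the original module: the forwarded binary
# ops accept an optional broadcast_dimensions keyword, which the wrapper above
# passes through as a trailing positional argument of the underlying C API
# method. The helper name is hypothetical.
def _example_broadcast_add(builder):
  matrix = builder.Constant(np.ones((3, 3), dtype=np.float32))
  row = builder.Constant(np.array([10.0, 20.0, 30.0], dtype=np.float32))
  # Replicate `row` across dimension 0 of `matrix` before adding.
  return builder.Add(matrix, row, broadcast_dimensions=(0,))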
|
898
tensorflow/compiler/xla/python/xla_client_test.py
Normal file
@ -0,0 +1,898 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the Python extension-based XLA client."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.xla.python import xla_client
|
||||
import unittest
|
||||
|
||||
|
||||
class LocalComputationTest(unittest.TestCase):
|
||||
"""Base class for running an XLA Computation through the local client."""
|
||||
|
||||
def _NewComputation(self, name=None):
|
||||
if name is None:
|
||||
name = self.id()
|
||||
return xla_client.ComputationBuilder(name)
|
||||
|
||||
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
|
||||
assert expected is not None
|
||||
compiled_c = c.Build().CompileWithExampleArguments(arguments)
|
||||
result = compiled_c.Execute(arguments)
|
||||
# Numpy's comparison methods are a bit too lenient by treating inputs as
|
||||
# "array-like", meaning that scalar 4 will be happily compared equal to
|
||||
# [[4]]. We'd like to be more strict so assert shapes as well.
|
||||
self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
|
||||
assert_func(result, expected)
|
||||
|
||||
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
|
||||
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
|
||||
|
||||
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
|
||||
self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
|
||||
expected)
|
||||
|
||||
|
||||
def NumpyArrayF32(*args, **kwargs):
|
||||
"""Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
|
||||
return np.array(*args, dtype=np.float32, **kwargs)
|
||||
|
||||
|
||||
def NumpyArrayF64(*args, **kwargs):
|
||||
"""Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
|
||||
return np.array(*args, dtype=np.float64, **kwargs)
|
||||
|
||||
|
||||
def NumpyArrayS32(*args, **kwargs):
|
||||
"""Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
|
||||
return np.array(*args, dtype=np.int32, **kwargs)
|
||||
|
||||
|
||||
def NumpyArrayS64(*args, **kwargs):
|
||||
"""Convenience wrapper to create Numpy arrays with a np.int64 dtype."""
|
||||
return np.array(*args, dtype=np.int64, **kwargs)
|
||||
|
||||
|
||||
def NumpyArrayBool(*args, **kwargs):
|
||||
"""Convenience wrapper to create Numpy arrays with a np.bool dtype."""
|
||||
return np.array(*args, dtype=np.bool, **kwargs)
|
||||
|
||||
|
||||
class ComputationsWithConstantsTest(LocalComputationTest):
|
||||
"""Tests focusing on Constant ops."""
|
||||
|
||||
def testConstantScalarSumF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
|
||||
self._ExecuteAndCompareClose(c, expected=4.25)
|
||||
|
||||
def testConstantScalarSumF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
|
||||
self._ExecuteAndCompareClose(c, expected=4.25)
|
||||
|
||||
def testConstantScalarSumS32(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
|
||||
self._ExecuteAndCompareClose(c, expected=3)
|
||||
|
||||
def testConstantScalarSumS64(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
|
||||
self._ExecuteAndCompareClose(c, expected=3)
|
||||
|
||||
def testConstantVectorMulF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Mul(
|
||||
c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
|
||||
c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
|
||||
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
|
||||
|
||||
def testConstantVectorMulF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Mul(
|
||||
c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
|
||||
c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
|
||||
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
|
||||
|
||||
def testConstantVectorScalarDivF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Div(
|
||||
c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
|
||||
c.ConstantF32Scalar(2.0))
|
||||
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
|
||||
|
||||
def testConstantVectorScalarDivF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Div(
|
||||
c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
|
||||
c.ConstantF64Scalar(2.0))
|
||||
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
|
||||
|
||||
def testConstantVectorScalarPowF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
|
||||
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
|
||||
|
||||
def testConstantVectorScalarPowF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
|
||||
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
|
||||
|
||||
def testBooleanAnd(self):
|
||||
c = self._NewComputation()
|
||||
c.And(
|
||||
c.Constant(NumpyArrayBool([True, False, True, False])),
|
||||
c.Constant(NumpyArrayBool([True, True, False, False])))
|
||||
self._ExecuteAndCompareExact(c, expected=[True, False, False, False])
|
||||
|
||||
def testBooleanOr(self):
|
||||
c = self._NewComputation()
|
||||
c.Or(
|
||||
c.Constant(NumpyArrayBool([True, False, True, False])),
|
||||
c.Constant(NumpyArrayBool([True, True, False, False])))
|
||||
self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
|
||||
|
||||
def testSum2DF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
|
||||
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
|
||||
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
|
||||
|
||||
def testSum2DF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
|
||||
c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
|
||||
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
|
||||
|
||||
def testSum2DWith1DBroadcastDim0F32(self):
|
||||
# sum of a 2D array with a 1D array where the latter is replicated across
|
||||
# dimension 0 to match the former's shape.
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF32([10, 20, 30])),
|
||||
broadcast_dimensions=(0,))
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
|
||||
|
||||
def testSum2DWith1DBroadcastDim0F64(self):
|
||||
# sum of a 2D array with a 1D array where the latter is replicated across
|
||||
# dimension 0 to match the former's shape.
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF64([10, 20, 30])),
|
||||
broadcast_dimensions=(0,))
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
|
||||
|
||||
def testSum2DWith1DBroadcastDim1F32(self):
|
||||
# sum of a 2D array with a 1D array where the latter is replicated across
|
||||
# dimension 1 to match the former's shape.
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF32([10, 20, 30])),
|
||||
broadcast_dimensions=(1,))
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
|
||||
|
||||
def testSum2DWith1DBroadcastDim1F64(self):
|
||||
# sum of a 2D array with a 1D array where the latter is replicated across
|
||||
# dimension 1 to match the former's shape.
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF64([10, 20, 30])),
|
||||
broadcast_dimensions=(1,))
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
|
||||
|
||||
def testConstantAxpyF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Mul(
|
||||
c.ConstantF32Scalar(2),
|
||||
c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
|
||||
c.Constant(NumpyArrayF32([100, -100, 200, -200])))
|
||||
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
|
||||
|
||||
def testConstantAxpyF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(
|
||||
c.Mul(
|
||||
c.ConstantF64Scalar(2),
|
||||
c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
|
||||
c.Constant(NumpyArrayF64([100, -100, 200, -200])))
|
||||
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
|
||||
|
||||
|
||||
class ParametersTest(LocalComputationTest):
|
||||
"""Tests focusing on Parameter ops and argument-passing."""
|
||||
|
||||
def setUp(self):
|
||||
self.f32_scalar_2 = NumpyArrayF32(2.0)
|
||||
self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
|
||||
self.f64_scalar_2 = NumpyArrayF64(2.0)
|
||||
self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
|
||||
self.s32_scalar_3 = NumpyArrayS32(3)
|
||||
self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
|
||||
self.s64_scalar_3 = NumpyArrayS64(3)
|
||||
self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])
|
||||
|
||||
def testScalarTimesVectorAutonumberF32(self):
|
||||
c = self._NewComputation()
|
||||
p0 = c.ParameterFromNumpy(self.f32_scalar_2)
|
||||
p1 = c.ParameterFromNumpy(self.f32_4vector)
|
||||
c.Mul(p0, p1)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[self.f32_scalar_2, self.f32_4vector],
|
||||
expected=[-4.6, 6.6, -8.6, 10.6])
|
||||
|
||||
def testScalarTimesVectorAutonumberF64(self):
|
||||
c = self._NewComputation()
|
||||
p0 = c.ParameterFromNumpy(self.f64_scalar_2)
|
||||
p1 = c.ParameterFromNumpy(self.f64_4vector)
|
||||
c.Mul(p0, p1)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[self.f64_scalar_2, self.f64_4vector],
|
||||
expected=[-4.6, 6.6, -8.6, 10.6])
|
||||
|
||||
def testScalarTimesVectorS32(self):
|
||||
c = self._NewComputation()
|
||||
p0 = c.ParameterFromNumpy(self.s32_scalar_3)
|
||||
p1 = c.ParameterFromNumpy(self.s32_4vector)
|
||||
c.Mul(p0, p1)
|
||||
self._ExecuteAndCompareExact(
|
||||
c,
|
||||
arguments=[self.s32_scalar_3, self.s32_4vector],
|
||||
expected=[30, 45, -6, 21])
|
||||
|
||||
def testScalarTimesVectorS64(self):
|
||||
c = self._NewComputation()
|
||||
p0 = c.ParameterFromNumpy(self.s64_scalar_3)
|
||||
p1 = c.ParameterFromNumpy(self.s64_4vector)
|
||||
c.Mul(p0, p1)
|
||||
self._ExecuteAndCompareExact(
|
||||
c,
|
||||
arguments=[self.s64_scalar_3, self.s64_4vector],
|
||||
expected=[30, 45, -6, 21])
|
||||
|
||||
def testScalarMinusVectorExplicitNumberingF32(self):
|
||||
# Use explicit numbering and pass parameter_num first. Sub is used since
|
||||
# it's not commutative and can help catch parameter reversal within the
|
||||
# computation.
|
||||
c = self._NewComputation()
|
||||
p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
|
||||
p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
|
||||
c.Sub(p1, p0)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[self.f32_scalar_2, self.f32_4vector],
|
||||
expected=[-4.3, 1.3, -6.3, 3.3])
|
||||
|
||||
def testScalarMinusVectorExplicitNumberingF64(self):
|
||||
# Use explicit numbering and pass parameter_num first. Sub is used since
|
||||
# it's not commutative and can help catch parameter reversal within the
|
||||
# computation.
|
||||
c = self._NewComputation()
|
||||
p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
|
||||
p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
|
||||
c.Sub(p1, p0)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[self.f64_scalar_2, self.f64_4vector],
|
||||
expected=[-4.3, 1.3, -6.3, 3.3])
|
||||
|
||||
|
||||
class SingleOpTest(LocalComputationTest):
|
||||
"""Tests for single ops.
|
||||
|
||||
The goal here is smoke testing - to exercise the most basic functionality of
|
||||
single XLA ops. As few additional ops as possible are added
|
||||
around the op being tested.
|
||||
"""
|
||||
|
||||
def testConcatenateF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Concatenate(
|
||||
(c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
|
||||
dimension=0)
|
||||
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def testConcatenateF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Concatenate(
|
||||
(c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
|
||||
dimension=0)
|
||||
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def testConvertElementType(self):
|
||||
xla_types = {
|
||||
np.bool: xla_client.xla_data_pb2.PRED,
|
||||
np.int32: xla_client.xla_data_pb2.S32,
|
||||
np.int64: xla_client.xla_data_pb2.S64,
|
||||
np.float32: xla_client.xla_data_pb2.F32,
|
||||
np.float64: xla_client.xla_data_pb2.F64,
|
||||
}
|
||||
|
||||
def _ConvertAndTest(template, src_dtype, dst_dtype):
|
||||
c = self._NewComputation()
|
||||
x = c.Constant(np.array(template, dtype=src_dtype))
|
||||
c.ConvertElementType(x, xla_types[dst_dtype])
|
||||
|
||||
result = c.Build().Compile().Execute()
|
||||
expected = np.array(template, dtype=dst_dtype)
|
||||
|
||||
self.assertEqual(result.shape, expected.shape)
|
||||
self.assertEqual(result.dtype, expected.dtype)
|
||||
np.testing.assert_equal(result, expected)
|
||||
|
||||
x = [0, 1, 0, 0, 1]
|
||||
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
|
||||
_ConvertAndTest(x, src_dtype, dst_dtype)
|
||||
|
||||
def testDotMatrixVectorF32(self):
|
||||
c = self._NewComputation()
|
||||
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
|
||||
rhs = NumpyArrayF32([[10.0], [20.0]])
|
||||
c.Dot(c.Constant(lhs), c.Constant(rhs))
|
||||
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
|
||||
|
||||
def testDotMatrixVectorF64(self):
|
||||
c = self._NewComputation()
|
||||
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
|
||||
rhs = NumpyArrayF64([[10.0], [20.0]])
|
||||
c.Dot(c.Constant(lhs), c.Constant(rhs))
|
||||
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
|
||||
|
||||
def testDotMatrixMatrixF32(self):
|
||||
c = self._NewComputation()
|
||||
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
|
||||
rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
|
||||
c.Dot(c.Constant(lhs), c.Constant(rhs))
|
||||
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
|
||||
|
||||
def testDotMatrixMatrixF64(self):
|
||||
c = self._NewComputation()
|
||||
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
|
||||
rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
|
||||
c.Dot(c.Constant(lhs), c.Constant(rhs))
|
||||
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
|
||||
|
||||
def testBooleanNot(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayBool([True, False, True])
|
||||
c.Not(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=~arr)
|
||||
|
||||
def testExp(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Exp(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
|
||||
|
||||
def testLog(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Log(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.log(arr))
|
||||
|
||||
def testNeg(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Neg(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=-arr)
|
||||
|
||||
def testFloor(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Floor(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.floor(arr))
|
||||
|
||||
def testCeil(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Ceil(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
|
||||
|
||||
def testAbs(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
|
||||
c.Abs(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.abs(arr))
|
||||
|
||||
def testTanh(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayF32([3.3, 12.1])
|
||||
c.Tanh(c.Constant(arr))
|
||||
self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
|
||||
|
||||
def testTrans(self):
|
||||
|
||||
def _TransposeAndTest(array):
|
||||
c = self._NewComputation()
|
||||
c.Trans(c.Constant(array))
|
||||
self._ExecuteAndCompareClose(c, expected=array.T)
|
||||
|
||||
# Test square and non-square matrices in both default (C) and F orders.
|
||||
for array_fun in [NumpyArrayF32, NumpyArrayF64]:
|
||||
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
|
||||
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
|
||||
_TransposeAndTest(array_fun([[1, 2], [4, 5]]))
|
||||
_TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
|
||||
|
||||
def testTranspose(self):
|
||||
|
||||
def _TransposeAndTest(array, permutation):
|
||||
c = self._NewComputation()
|
||||
c.Transpose(c.Constant(array), permutation)
|
||||
expected = np.transpose(array, permutation)
|
||||
self._ExecuteAndCompareClose(c, expected=expected)
|
||||
|
||||
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
|
||||
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
|
||||
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
|
||||
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
|
||||
|
||||
arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
|
||||
for permutation in itertools.permutations(range(arr.ndim)):
|
||||
_TransposeAndTest(arr, permutation)
|
||||
_TransposeAndTest(np.asfortranarray(arr), permutation)
|
||||
|
||||
def testEq(self):
|
||||
c = self._NewComputation()
|
||||
c.Eq(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
|
||||
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
|
||||
self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
|
||||
|
||||
def testNe(self):
|
||||
c = self._NewComputation()
|
||||
c.Ne(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
|
||||
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
|
||||
self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
|
||||
|
||||
c.Ne(
|
||||
c.Constant(NumpyArrayF32([-2.0, 0.0,
|
||||
float("nan"),
|
||||
float("nan")])),
|
||||
c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
|
||||
self._ExecuteAndAssertWith(
|
||||
np.testing.assert_allclose, c, (), expected=[True, False, True, True])
|
||||
|
||||
def testGt(self):
|
||||
c = self._NewComputation()
|
||||
c.Gt(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
|
||||
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
|
||||
self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
|
||||
|
||||
def testGe(self):
|
||||
c = self._NewComputation()
|
||||
c.Ge(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
|
||||
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
|
||||
self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
|
||||
|
||||
def testLt(self):
|
||||
c = self._NewComputation()
|
||||
c.Lt(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
|
||||
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
|
||||
self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
|
||||
|
||||
def testLe(self):
|
||||
c = self._NewComputation()
|
||||
c.Le(
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
|
||||
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
|
||||
self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
|
||||
|
||||
def testMax(self):
|
||||
c = self._NewComputation()
|
||||
c.Max(
|
||||
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
|
||||
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
|
||||
self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
|
||||
|
||||
def testMaxExplicitBroadcastDim0(self):
|
||||
c = self._NewComputation()
|
||||
c.Max(
|
||||
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF32([3, 4, 5])),
|
||||
broadcast_dimensions=(0,))
|
||||
self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
|
||||
|
||||
def testMaxExplicitBroadcastDim1(self):
|
||||
c = self._NewComputation()
|
||||
c.Max(
|
||||
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayF32([3, 4, 5])),
|
||||
broadcast_dimensions=(1,))
|
||||
self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
|
||||
|
||||
def testMin(self):
|
||||
c = self._NewComputation()
|
||||
c.Min(
|
||||
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
|
||||
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
|
||||
self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
|
||||
|
||||
def testReshape(self):
|
||||
c = self._NewComputation()
|
||||
c.Reshape(
|
||||
c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
|
||||
dimensions=[0, 1],
|
||||
new_sizes=[2, 3])
|
||||
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
|
||||
|
||||
def testSelect(self):
|
||||
c = self._NewComputation()
|
||||
c.Select(
|
||||
c.Constant(NumpyArrayBool([True, False, False, True, False])),
|
||||
c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
|
||||
c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
|
||||
self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
|
||||
|
||||
def testSlice(self):
|
||||
c = self._NewComputation()
|
||||
c.Slice(
|
||||
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
|
||||
[3, 2])
|
||||
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
|
||||
|
||||
def testDynamicSlice(self):
|
||||
c = self._NewComputation()
|
||||
c.DynamicSlice(
|
||||
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayS32([1, 0])), [2, 2])
|
||||
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
|
||||
|
||||
def testDynamicUpdateSlice(self):
|
||||
c = self._NewComputation()
|
||||
c.DynamicUpdateSlice(
|
||||
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
|
||||
c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
|
||||
c.Constant(NumpyArrayS32([1, 1])))
|
||||
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
|
||||
|
||||
def testTuple(self):
|
||||
c = self._NewComputation()
|
||||
c.Tuple(
|
||||
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
|
||||
c.Constant(NumpyArrayBool([True, False, False, True])))
|
||||
result = c.Build().Compile().Execute()
|
||||
self.assertIsInstance(result, tuple)
|
||||
np.testing.assert_equal(result[0], 42)
|
||||
np.testing.assert_allclose(result[1], [1.0, 2.0])
|
||||
np.testing.assert_equal(result[2], [True, False, False, True])
|
||||
|
||||
def testGetTupleElement(self):
|
||||
c = self._NewComputation()
|
||||
c.GetTupleElement(
|
||||
c.Tuple(
|
||||
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
|
||||
c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
|
||||
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
|
||||
|
||||
def testBroadcast(self):
|
||||
c = self._NewComputation()
|
||||
c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
|
||||
self._ExecuteAndCompareExact(
|
||||
c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
|
||||
|
||||
|
||||
class EmbeddedComputationsTest(LocalComputationTest):
|
||||
"""Tests for XLA graphs with embedded computations (such as maps)."""
|
||||
|
||||
def _CreateConstantS32Computation(self):
|
||||
"""Computation (f32) -> s32 that returns a constant 1 for any input."""
|
||||
c = self._NewComputation("constant_s32_one")
|
||||
# TODO(eliben): consider adding a nicer way to create new parameters without
|
||||
# having to create dummy Numpy arrays or populate Shape messages. Perhaps
|
||||
# we need our own (Python-client-own) way to represent Shapes conveniently.
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0))
|
||||
c.ConstantS32Scalar(1)
|
||||
return c.Build()
|
||||
|
||||
def _CreateConstantS64Computation(self):
|
||||
"""Computation (f64) -> s64 that returns a constant 1 for any input."""
|
||||
c = self._NewComputation("constant_s64_one")
|
||||
# TODO(eliben): consider adding a nicer way to create new parameters without
|
||||
# having to create dummy Numpy arrays or populate Shape messages. Perhaps
|
||||
# we need our own (Python-client-own) way to represent Shapes conveniently.
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0))
|
||||
c.ConstantS64Scalar(1)
|
||||
return c.Build()
|
||||
|
||||
def _CreateConstantF32Computation(self):
|
||||
"""Computation (f32) -> f32 that returns a constant 1.0 for any input."""
|
||||
c = self._NewComputation("constant_f32_one")
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0))
|
||||
c.ConstantF32Scalar(1.0)
|
||||
return c.Build()
|
||||
|
||||
def _CreateConstantF64Computation(self):
|
||||
"""Computation (f64) -> f64 that returns a constant 1.0 for any input."""
|
||||
c = self._NewComputation("constant_f64_one")
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0))
|
||||
c.ConstantF64Scalar(1.0)
|
||||
return c.Build()
|
||||
|
||||
def _CreateMulF32By2Computation(self):
|
||||
"""Computation (f32) -> f32 that multiplies its parameter by 2."""
|
||||
c = self._NewComputation("mul_f32_by2")
|
||||
c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
|
||||
return c.Build()
|
||||
|
||||
def _CreateMulF64By2Computation(self):
|
||||
"""Computation (f64) -> f64 that multiplies its parameter by 2."""
|
||||
c = self._NewComputation("mul_f64_by2")
|
||||
c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
|
||||
return c.Build()
|
||||
|
||||
def _CreateBinaryAddF32Computation(self):
|
||||
"""Computation (f32, f32) -> f32 that adds its two parameters."""
|
||||
c = self._NewComputation("add_param0_by_param1")
|
||||
c.Add(
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateBinaryAddF64Computation(self):
|
||||
"""Computation (f64, f64) -> f64 that adds its two parameters."""
|
||||
c = self._NewComputation("add_param0_by_param1")
|
||||
c.Add(
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateBinaryDivF32Computation(self):
|
||||
"""Computation (f32, f32) -> f32 that divides its two parameters."""
|
||||
c = self._NewComputation("div_param0_by_param1")
|
||||
c.Div(
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateBinaryDivF64Computation(self):
|
||||
"""Computation (f64, f64) -> f64 that divides its two parameters."""
|
||||
c = self._NewComputation("div_param0_by_param1")
|
||||
c.Div(
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateTestF32Lt10Computation(self):
|
||||
"""Computation (f32) -> bool that tests if its parameter is less than 10."""
|
||||
c = self._NewComputation("test_f32_lt_10")
|
||||
c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
|
||||
return c.Build()
|
||||
|
||||
def _CreateTestF64Lt10Computation(self):
|
||||
"""Computation (f64) -> bool that tests if its parameter is less than 10."""
|
||||
c = self._NewComputation("test_f64_lt_10")
|
||||
c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
|
||||
return c.Build()
|
||||
|
||||
def _MakeSample3DArrayF32(self):
|
||||
return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
|
||||
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
|
||||
|
||||
def _MakeSample3DArrayF64(self):
|
||||
return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
|
||||
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
|
||||
|
||||
def testCallF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Call(
|
||||
self._CreateMulF32By2Computation(),
|
||||
operands=(c.ConstantF32Scalar(5.0),))
|
||||
self._ExecuteAndCompareClose(c, expected=10.0)
|
||||
|
||||
def testCallF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Call(
|
||||
self._CreateMulF64By2Computation(),
|
||||
operands=(c.ConstantF64Scalar(5.0),))
|
||||
self._ExecuteAndCompareClose(c, expected=10.0)
|
||||
|
||||
def testMapEachElementToS32Constant(self):
|
||||
c = self._NewComputation()
|
||||
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateConstantS32Computation(), [0])
|
||||
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
|
||||
|
||||
def testMapEachElementToS64Constant(self):
|
||||
c = self._NewComputation()
|
||||
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateConstantS64Computation(), [0])
|
||||
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
|
||||
|
||||
def testMapMulBy2F32(self):
|
||||
c = self._NewComputation()
|
||||
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateMulF32By2Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
|
||||
|
||||
def testMapMulBy2F64(self):
|
||||
c = self._NewComputation()
|
||||
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateMulF64By2Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
|
||||
|
||||
def testSimpleMapChainF32(self):
|
||||
# Chains a map of constant-f32 with a map of mul-by-2
|
||||
c = self._NewComputation()
|
||||
const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateConstantF32Computation(), [0])
|
||||
c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
|
||||
|
||||
def testSimpleMapChainF64(self):
|
||||
# Chains a map of constant-f64 with a map of mul-by-2
|
||||
c = self._NewComputation()
|
||||
const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
|
||||
self._CreateConstantF64Computation(), [0])
|
||||
c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
|
||||
|
||||
def testDivVectorsWithMapF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
|
||||
c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
|
||||
self._CreateBinaryDivF32Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
|
||||
|
||||
def testDivVectorsWithMapF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
|
||||
c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
|
||||
self._CreateBinaryDivF64Computation(), [0])
|
||||
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
|
||||
|
||||
def testReduce1DtoScalarF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
dimensions=[0])
|
||||
self._ExecuteAndCompareClose(c, expected=10)
|
||||
|
||||
def testReduce1DtoScalarF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
dimensions=[0])
|
||||
self._ExecuteAndCompareClose(c, expected=10)
|
||||
|
||||
def testReduce2DTo1DDim0F32(self):
|
||||
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
dimensions=[0])
|
||||
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
|
||||
|
||||
def testReduce2DTo1DDim0F64(self):
|
||||
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
dimensions=[0])
|
||||
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
|
||||
|
||||
def testReduce2DTo1DDim1F32(self):
|
||||
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
dimensions=[1])
|
||||
self._ExecuteAndCompareClose(c, expected=[6, 15])
|
||||
|
||||
def testReduce2DTo1DDim1F64(self):
|
||||
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
dimensions=[1])
|
||||
self._ExecuteAndCompareClose(c, expected=[6, 15])
|
||||
|
||||
def testReduce3DAllPossibleWaysF32(self):
|
||||
input_array = self._MakeSample3DArrayF32()
|
||||
|
||||
def _ReduceAndTest(*dims):
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
dimensions=dims)
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=np.sum(input_array, axis=tuple(dims)))
|
||||
|
||||
_ReduceAndTest(0)
|
||||
_ReduceAndTest(0)
|
||||
_ReduceAndTest(0, 1)
|
||||
_ReduceAndTest(0, 2)
|
||||
_ReduceAndTest(1, 2)
|
||||
_ReduceAndTest(0, 1, 2)
|
||||
|
||||
def testReduce3DAllPossibleWaysF64(self):
|
||||
input_array = self._MakeSample3DArrayF64()
|
||||
|
||||
def _ReduceAndTest(*dims):
|
||||
c = self._NewComputation()
|
||||
c.Reduce(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
dimensions=dims)
|
||||
self._ExecuteAndCompareClose(
|
||||
c, expected=np.sum(input_array, axis=tuple(dims)))
|
||||
|
||||
_ReduceAndTest(0)
|
||||
_ReduceAndTest(0)
|
||||
_ReduceAndTest(0, 1)
|
||||
_ReduceAndTest(0, 2)
|
||||
_ReduceAndTest(1, 2)
|
||||
_ReduceAndTest(0, 1, 2)
|
||||
|
||||
def testWhileF32(self):
|
||||
cond = self._CreateTestF32Lt10Computation()
|
||||
body = self._CreateMulF32By2Computation()
|
||||
c = self._NewComputation()
|
||||
init = c.ConstantF32Scalar(1.)
|
||||
c.While(cond, body, init)
|
||||
self._ExecuteAndCompareClose(c, expected=16.)
|
||||
|
||||
def testWhileF64(self):
|
||||
cond = self._CreateTestF64Lt10Computation()
|
||||
body = self._CreateMulF64By2Computation()
|
||||
c = self._NewComputation()
|
||||
init = c.ConstantF64Scalar(1.)
|
||||
c.While(cond, body, init)
|
||||
self._ExecuteAndCompareClose(c, expected=16.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -4,3 +4,4 @@
|
||||
*TF_*
|
||||
*TFE_*
|
||||
*nsync_*
|
||||
*pywrap_xla*
|
||||
|
@ -5,6 +5,7 @@ tensorflow {
|
||||
*TF_*;
|
||||
*TFE_*;
|
||||
*nsync_*;
|
||||
*pywrap_xla*;
|
||||
local:
|
||||
*;
|
||||
};
|
||||
|