[TF:XLA] Split literal_util into {literal, literal_util}.
Currently Literal classes sits in literal_util.{h,cc} instead of literal.{h,cc}. It also contains helper functions that are better fit to be their own separate class/namespace. This change starts this process by moving most static factory methods to LiteralUtil namespace. PiperOrigin-RevId: 203217065
This commit is contained in:
parent
1e7dde8791
commit
8779f768a3
@ -162,7 +162,7 @@ cc_library(
|
||||
":sharding_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -202,7 +202,7 @@ cc_library(
|
||||
],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -285,6 +285,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":tf2xla",
|
||||
":tf2xla_proto",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
@ -327,7 +328,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
@ -364,6 +365,7 @@ tf_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -114,6 +114,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla/lib:while_loop",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -159,7 +160,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -175,7 +176,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -210,6 +211,7 @@ tf_kernel_library(
|
||||
":index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
|
||||
std::vector<xla::XlaOp> args;
|
||||
args.push_back(ctx->Input(0));
|
||||
args.push_back(xla::ConstantLiteral(
|
||||
&b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
|
||||
&b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
|
||||
if (input_shape.dims() > 1) {
|
||||
// Don't bother passing the output shape and dim for the 1d case, since
|
||||
// the shape is always a scalar and the dim is always 0.
|
||||
args.push_back(xla::ConstantLiteral(
|
||||
&b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
|
||||
&b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
|
||||
args.push_back(
|
||||
xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
|
||||
xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
|
||||
}
|
||||
|
||||
xla::Shape xla_shape =
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/numeric.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
|
@ -40,7 +40,7 @@ cc_library(
|
||||
":triangular_solve",
|
||||
":util",
|
||||
":while_loop",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -73,7 +73,7 @@ cc_library(
|
||||
deps = [
|
||||
":util",
|
||||
":while_loop",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -92,7 +92,7 @@ cc_library(
|
||||
deps = [
|
||||
":batch_dot",
|
||||
":util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -111,7 +111,7 @@ xla_test(
|
||||
deps = [
|
||||
":triangular_solve",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -133,6 +133,7 @@ cc_library(
|
||||
srcs = ["util.cc"],
|
||||
hdrs = ["util.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -151,7 +152,7 @@ xla_test(
|
||||
":batch_dot",
|
||||
":util",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
|
@ -84,7 +84,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
|
||||
dimensions.push_back(y_shape.dimensions(y_outer_dim));
|
||||
return xla::Broadcast(
|
||||
xla::ConstantLiteral(builder,
|
||||
xla::Literal::Zero(x_shape.element_type())),
|
||||
xla::LiteralUtil::Zero(x_shape.element_type())),
|
||||
dimensions);
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -114,7 +114,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
|
||||
auto buffer = loop_vars[2];
|
||||
|
||||
auto zero_index = xla::ConstantLiteral(
|
||||
body_builder, xla::Literal::Zero(indices_shape.element_type()));
|
||||
body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
|
||||
|
||||
// Slice the i-th index from the indices array.
|
||||
xla::XlaOp index;
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -28,6 +29,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
|
||||
return xla::Broadcast(
|
||||
xla::ConstantLiteral(builder,
|
||||
xla::LiteralUtil::Zero(shape.element_type())),
|
||||
xla::AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
|
||||
xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
double value) {
|
||||
switch (type) {
|
||||
@ -56,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
xla::Literal literal;
|
||||
switch (type) {
|
||||
case xla::U8:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint8>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
|
||||
break;
|
||||
case xla::U32:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint32>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
|
||||
break;
|
||||
case xla::U64:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint64>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
|
||||
break;
|
||||
case xla::S8:
|
||||
literal = std::move(*xla::Literal::CreateR0<int8>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
|
||||
break;
|
||||
case xla::S32:
|
||||
literal = std::move(*xla::Literal::CreateR0<int32>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
|
||||
break;
|
||||
case xla::S64:
|
||||
literal = std::move(*xla::Literal::CreateR0<int64>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
|
||||
break;
|
||||
case xla::F32:
|
||||
literal = std::move(*xla::Literal::CreateR0<float>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
|
||||
break;
|
||||
case xla::F64:
|
||||
literal = std::move(*xla::Literal::CreateR0<double>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
|
||||
break;
|
||||
case xla::C64:
|
||||
literal = std::move(*xla::Literal::CreateR0<complex64>(value));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
|
||||
break;
|
||||
case xla::PRED:
|
||||
LOG(FATAL) << "pred element type is not integral";
|
||||
@ -89,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case xla::BF16:
|
||||
literal = std::move(
|
||||
*xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
|
||||
*xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
|
||||
break;
|
||||
case xla::F16:
|
||||
literal = std::move(
|
||||
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
|
||||
literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
|
||||
static_cast<xla::half>(value)));
|
||||
break;
|
||||
case xla::TUPLE:
|
||||
LOG(FATAL) << "tuple element type is not integral";
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
|
@ -100,8 +100,9 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
|
||||
std::vector<xla::XlaOp> updated_values;
|
||||
updated_values.reserve(values.size());
|
||||
updated_values.push_back(xla::Add(
|
||||
iteration, xla::ConstantLiteral(
|
||||
body_builder, xla::Literal::One(num_iterations_type))));
|
||||
iteration,
|
||||
xla::ConstantLiteral(body_builder,
|
||||
xla::LiteralUtil::One(num_iterations_type))));
|
||||
|
||||
values.remove_prefix(1);
|
||||
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
|
||||
@ -113,8 +114,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
|
||||
|
||||
std::vector<xla::XlaOp> values;
|
||||
values.reserve(initial_values.size() + 1);
|
||||
values.push_back(
|
||||
xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type)));
|
||||
values.push_back(xla::ConstantLiteral(
|
||||
builder, xla::LiteralUtil::Zero(num_iterations_type)));
|
||||
values.insert(values.end(), initial_values.begin(), initial_values.end());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
|
||||
{
|
||||
std::vector<int64> int64_values = {1, 2, 3};
|
||||
std::unique_ptr<xla::Literal> int64_values_literal =
|
||||
xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values));
|
||||
xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
|
||||
Tensor host_tensor;
|
||||
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
|
||||
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
|
||||
@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
|
||||
Tensor host_tensor;
|
||||
std::vector<int32> int32_values = {10, 11};
|
||||
std::unique_ptr<xla::Literal> int32_values_literal =
|
||||
xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values));
|
||||
xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
|
||||
EXPECT_TRUE(
|
||||
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
|
||||
.ok());
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
@ -73,8 +74,8 @@ TEST(ConvertGraphDefToXla, Sum) {
|
||||
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
|
||||
|
||||
// Set up arguments.
|
||||
auto x_literal = xla::Literal::CreateR0<int32>(10);
|
||||
auto y_literal = xla::Literal::CreateR0<int32>(32);
|
||||
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
|
||||
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
|
||||
auto x_global_or = client->TransferToServer(*x_literal);
|
||||
auto y_global_or = client->TransferToServer(*y_literal);
|
||||
TF_EXPECT_OK(x_global_or.status());
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::Literal::CreateR1<int32>({-3, 101});
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
@ -222,9 +222,9 @@ TEST_F(XlaCompilerTest, Simple) {
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR1<int32>({4, 143});
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 143});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
@ -306,7 +306,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
|
||||
@ -317,9 +317,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR1<int32>({-7, -42});
|
||||
xla::LiteralUtil::CreateR1<int32>({-7, -42});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get()});
|
||||
EXPECT_TRUE(
|
||||
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
@ -341,7 +341,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
|
||||
@ -351,11 +351,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::LiteralUtil::CreateR0<int32>(7);
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::Literal::CreateR1<int32>({-7, -42});
|
||||
xla::LiteralUtil::CreateR1<int32>({-7, -42});
|
||||
std::unique_ptr<xla::Literal> expected =
|
||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
|
||||
}
|
||||
}
|
||||
@ -569,11 +570,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> input_base =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> input_grad2 =
|
||||
xla::Literal::CreateR1<int32>({-3, 101});
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::Literal> input =
|
||||
xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
|
||||
xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*input).ConsumeValueOrDie();
|
||||
|
||||
@ -583,17 +584,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
|
||||
std::unique_ptr<xla::Literal> output_read =
|
||||
xla::LiteralUtil::CreateR0<int32>(42);
|
||||
std::unique_ptr<xla::Literal> output_base =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> output_grad1 =
|
||||
xla::Literal::CreateR1<int32>({0, 1});
|
||||
xla::LiteralUtil::CreateR1<int32>({0, 1});
|
||||
std::unique_ptr<xla::Literal> output_grad2 =
|
||||
xla::Literal::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
|
||||
{output_base.get(), output_grad1.get(), output_grad2.get()});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
|
||||
xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
@ -796,9 +798,9 @@ TEST_F(XlaCompilerTest, Variables) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({7, 42});
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::Literal::CreateR1<int32>({-3, 101});
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
@ -812,11 +814,11 @@ TEST_F(XlaCompilerTest, Variables) {
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR1<int32>({5, 144});
|
||||
xla::LiteralUtil::CreateR1<int32>({5, 144});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::Literal::CreateR1<int32>({4, 143});
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 143});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
@ -884,9 +886,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
|
||||
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::Literal::CreateR1<int32>({22, 11, 33, 404});
|
||||
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
@ -900,11 +902,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
|
||||
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
|
||||
xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
@ -953,9 +955,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({4, 55, 1, -3});
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::Literal::CreateR1<int32>({22, 11, 33, 404});
|
||||
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
@ -969,11 +971,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR1<int32>({27, 67, 35, 402});
|
||||
xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
|
||||
xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
@ -94,13 +94,13 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
|
||||
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
|
||||
xla::PrimitiveType type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
|
||||
return xla::ConstantLiteral(b, xla::Literal::Zero(type));
|
||||
return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
|
||||
}
|
||||
|
||||
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
|
||||
xla::PrimitiveType type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
|
||||
return xla::ConstantLiteral(b, xla::Literal::One(type));
|
||||
return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
|
||||
}
|
||||
|
||||
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
|
||||
|
@ -281,9 +281,9 @@ tf_cc_test(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "literal_util",
|
||||
srcs = ["literal_util.cc"],
|
||||
hdrs = ["literal_util.h"],
|
||||
name = "literal",
|
||||
srcs = ["literal.cc"],
|
||||
hdrs = ["literal.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":array2d",
|
||||
@ -300,11 +300,12 @@ cc_library(
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "literal_util_test",
|
||||
srcs = ["literal_util_test.cc"],
|
||||
name = "literal_test",
|
||||
srcs = ["literal_test.cc"],
|
||||
deps = [
|
||||
":array3d",
|
||||
":array4d",
|
||||
":literal",
|
||||
":literal_util",
|
||||
":shape_util",
|
||||
":test",
|
||||
@ -316,6 +317,26 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "literal_util",
|
||||
srcs = ["literal_util.cc"],
|
||||
hdrs = ["literal_util.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":array2d",
|
||||
":array3d",
|
||||
":array4d",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":sparse_index_array",
|
||||
":status_macros",
|
||||
":types",
|
||||
":util",
|
||||
":xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "error_spec",
|
||||
hdrs = ["error_spec.h"],
|
||||
@ -327,6 +348,7 @@ cc_library(
|
||||
hdrs = ["literal_comparison.h"],
|
||||
deps = [
|
||||
":error_spec",
|
||||
":literal",
|
||||
":literal_util",
|
||||
":util",
|
||||
"//tensorflow/core:lib",
|
||||
@ -458,7 +480,7 @@ cc_library(
|
||||
hdrs = ["packed_literal_reader.h"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":literal_util",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":statusor",
|
||||
@ -489,7 +511,7 @@ cc_library(
|
||||
hdrs = ["text_literal_reader.h"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":literal_util",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":statusor",
|
||||
@ -505,7 +527,7 @@ tf_cc_test(
|
||||
name = "text_literal_reader_test",
|
||||
srcs = ["text_literal_reader_test.cc"],
|
||||
deps = [
|
||||
":literal_util",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":test",
|
||||
":text_literal_reader",
|
||||
@ -522,7 +544,7 @@ cc_library(
|
||||
hdrs = ["text_literal_writer.h"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":literal_util",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":types",
|
||||
@ -535,6 +557,7 @@ tf_cc_test(
|
||||
name = "text_literal_writer_test",
|
||||
srcs = ["text_literal_writer_test.cc"],
|
||||
deps = [
|
||||
":literal",
|
||||
":literal_util",
|
||||
":test",
|
||||
":test_helpers",
|
||||
@ -607,6 +630,7 @@ cc_library(
|
||||
":array2d",
|
||||
":array3d",
|
||||
":array4d",
|
||||
":literal_util",
|
||||
":util",
|
||||
":window_util",
|
||||
":xla_data_proto",
|
||||
@ -627,7 +651,7 @@ tf_cc_test(
|
||||
":array2d",
|
||||
":array3d",
|
||||
":array4d",
|
||||
":literal_util",
|
||||
":literal",
|
||||
":reference_util",
|
||||
":test",
|
||||
":util",
|
||||
|
@ -65,7 +65,7 @@ cc_library(
|
||||
deps = [
|
||||
":global_data",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:service_interface",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -82,6 +82,7 @@ xla_test(
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":math",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -123,7 +124,7 @@ cc_library(
|
||||
hdrs = ["testing.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
|
||||
return ConstantLiteral(builder, Literal::Zero(type));
|
||||
return ConstantLiteral(builder, LiteralUtil::Zero(type));
|
||||
}
|
||||
|
||||
XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
|
||||
@ -38,7 +38,7 @@ XlaOp ZerosLike(XlaOp prototype) {
|
||||
}
|
||||
|
||||
XlaOp One(XlaBuilder* builder, PrimitiveType type) {
|
||||
return ConstantLiteral(builder, Literal::One(type));
|
||||
return ConstantLiteral(builder, LiteralUtil::One(type));
|
||||
}
|
||||
|
||||
XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
|
||||
@ -61,7 +61,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
|
||||
}
|
||||
|
||||
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
|
||||
return ConstantLiteral(builder, Literal::MinValue(type));
|
||||
return ConstantLiteral(builder, LiteralUtil::MinValue(type));
|
||||
}
|
||||
|
||||
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
||||
@ -81,7 +81,7 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
||||
}
|
||||
|
||||
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
|
||||
return ConstantLiteral(builder, Literal::MaxValue(type));
|
||||
return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
|
||||
}
|
||||
|
||||
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
@ -31,7 +32,7 @@ class MathTest : public ClientLibraryTestBase {
|
||||
|
||||
XLA_TEST_F(MathTest, SqrtF32) {
|
||||
XlaBuilder builder(TestName());
|
||||
Literal zero_literal = Literal::Zero(PrimitiveType::F32);
|
||||
Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32);
|
||||
|
||||
std::unique_ptr<GlobalData> zero_data =
|
||||
client_->TransferToServer(zero_literal).ConsumeValueOrDie();
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
@ -49,7 +49,7 @@ int64 DataSizeOfShape(const Shape& shape) {
|
||||
XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
|
||||
if (ShapeUtil::IsArray(shape)) {
|
||||
return Broadcast(
|
||||
ConstantLiteral(builder, Literal::One(shape.element_type())),
|
||||
ConstantLiteral(builder, LiteralUtil::One(shape.element_type())),
|
||||
AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
std::vector<XlaOp> parts;
|
||||
|
@ -43,6 +43,7 @@ cc_library(
|
||||
deps = [
|
||||
":xla_computation",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -64,7 +65,7 @@ tf_cc_test(
|
||||
srcs = ["xla_builder_test.cc"],
|
||||
deps = [
|
||||
":xla_builder",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
|
@ -736,7 +736,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
|
||||
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = ShapeUtil::MakeNil();
|
||||
*instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
|
||||
*instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
|
||||
});
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -1943,12 +1944,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR0(NativeT value) {
|
||||
return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
|
||||
return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
|
||||
return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
|
||||
return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -1960,44 +1961,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
|
||||
}
|
||||
|
||||
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
|
||||
return ConstantLiteral(*Literal::CreateR1(values));
|
||||
return ConstantLiteral(*LiteralUtil::CreateR1(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR2(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values) {
|
||||
return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
|
||||
return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
|
||||
const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
*Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
|
||||
return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
|
||||
return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
|
||||
const Array2D<NativeT>& values, const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
*Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
|
||||
return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
|
||||
return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
|
||||
const Array3D<NativeT>& values, const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
*Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
|
||||
*LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -2020,13 +2021,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
|
||||
return ConstantLiteral(builder, *Literal::CreateR0<NativeT>(value));
|
||||
return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp ConstantR1(XlaBuilder* builder,
|
||||
tensorflow::gtl::ArraySlice<NativeT> values) {
|
||||
return ConstantLiteral(builder, *Literal::CreateR1<NativeT>(values));
|
||||
return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -2039,13 +2040,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
|
||||
|
||||
inline XlaOp ConstantR1(XlaBuilder* builder,
|
||||
const tensorflow::core::Bitmap& values) {
|
||||
return ConstantLiteral(builder, *Literal::CreateR1(values));
|
||||
return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp ConstantR2(XlaBuilder* builder,
|
||||
std::initializer_list<std::initializer_list<NativeT>> values) {
|
||||
return ConstantLiteral(builder, *Literal::CreateR2<NativeT>(values));
|
||||
return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -2053,12 +2054,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
|
||||
const Array<NativeT>& values,
|
||||
const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
builder,
|
||||
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
|
||||
return ConstantLiteral(builder, *Literal::CreateFromArray<NativeT>(values));
|
||||
return ConstantLiteral(builder,
|
||||
*LiteralUtil::CreateFromArray<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -2066,14 +2069,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
|
||||
const Array2D<NativeT>& values,
|
||||
const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
builder,
|
||||
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
|
||||
const Array2D<NativeT>& values) {
|
||||
return ConstantLiteral(builder,
|
||||
*Literal::CreateR2FromArray2D<NativeT>(values));
|
||||
*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
@ -2082,7 +2086,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
|
||||
const Layout& layout) {
|
||||
return ConstantLiteral(
|
||||
builder,
|
||||
*Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
|
||||
*LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
|
1967
tensorflow/compiler/xla/literal.cc
Normal file
1967
tensorflow/compiler/xla/literal.cc
Normal file
File diff suppressed because it is too large
Load Diff
1152
tensorflow/compiler/xla/literal.h
Normal file
1152
tensorflow/compiler/xla/literal.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/casts.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -217,7 +218,7 @@ class NearComparator {
|
||||
return Printf(
|
||||
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
|
||||
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
|
||||
Literal::MultiIndexAsString(
|
||||
LiteralUtil::MultiIndexAsString(
|
||||
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
|
||||
linear_index))
|
||||
.c_str(),
|
||||
@ -722,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||
return AppendStatus(result,
|
||||
tensorflow::strings::Printf(
|
||||
"\nat index: %s\nexpected: %s\nactual: %s",
|
||||
Literal::MultiIndexAsString(multi_index).c_str(),
|
||||
LiteralUtil::MultiIndexAsString(multi_index).c_str(),
|
||||
ToStringTruncated(expected).c_str(),
|
||||
ToStringTruncated(actual).c_str()));
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/error_spec.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace xla {
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -33,6 +33,7 @@ cc_library(
|
||||
srcs = ["numpy_bridge.cc"],
|
||||
hdrs = ["numpy_bridge.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -70,7 +71,7 @@ tf_py_wrap_cc(
|
||||
deps = [
|
||||
":local_computation_builder",
|
||||
":numpy_bridge",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
|
@ -109,7 +109,7 @@ limitations under the License.
|
||||
// Must be included first
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -374,7 +375,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
|
||||
TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
|
||||
elements.push_back(std::move(literal));
|
||||
}
|
||||
return Literal::MakeTupleOwned(std::move(elements));
|
||||
return LiteralUtil::MakeTupleOwned(std::move(elements));
|
||||
} else if (PyArray_Check(o)) {
|
||||
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
|
||||
int rank = PyArray_NDIM(py_array);
|
||||
@ -383,7 +384,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
|
||||
dimensions[i] = PyArray_DIM(py_array, i);
|
||||
}
|
||||
int np_type = PyArray_TYPE(py_array);
|
||||
auto literal = Literal::CreateFromDimensions(
|
||||
auto literal = LiteralUtil::CreateFromDimensions(
|
||||
NumpyTypeToPrimitiveType(np_type), dimensions);
|
||||
TF_RETURN_IF_ERROR(
|
||||
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
|
||||
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
|
||||
ConvolutionDimensionNumbers dnums) {
|
||||
HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
|
||||
auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
|
||||
auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
|
||||
auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
|
||||
auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
|
||||
|
||||
std::array<int64, 2> ordered_kernel_strides;
|
||||
std::array<int64, 2> ordered_input_dimensions;
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/array3d.h"
|
||||
#include "tensorflow/compiler/xla/array4d.h"
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test {
|
||||
|
||||
TEST_F(ReferenceUtilTest, TransposeArray2D) {
|
||||
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
}
|
||||
@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
|
||||
{11.f, 12.f},
|
||||
});
|
||||
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
}
|
||||
@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
|
||||
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
|
||||
auto add = [](float lhs, float rhs) { return lhs + rhs; };
|
||||
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
|
||||
auto actual_literal = Literal::CreateR1<float>(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
|
||||
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
}
|
||||
@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
|
||||
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
|
||||
auto add = [](float lhs, float rhs) { return lhs + rhs; };
|
||||
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
|
||||
auto actual_literal = Literal::CreateR1<float>(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
|
||||
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
}
|
||||
|
||||
TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
|
||||
auto result = Literal::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
|
||||
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
|
||||
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
|
||||
[](float a, float b) { return a + b; }));
|
||||
LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
|
||||
@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
|
||||
TEST_F(ReferenceUtilTest, MapArray2D) {
|
||||
auto identity = [](float value) { return log(exp(value)); };
|
||||
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
}
|
||||
@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
|
||||
return value + row + col;
|
||||
};
|
||||
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
}
|
||||
@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
|
||||
input->FillWithMultiples(1.0f);
|
||||
auto multiply_by_two = [](float value) { return 2 * value; };
|
||||
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
|
||||
|
||||
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
|
||||
expected.FillWithMultiples(2.0f);
|
||||
@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
|
||||
return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
|
||||
};
|
||||
auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
|
||||
|
||||
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
|
||||
expected.Fill(0.0f);
|
||||
@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
|
||||
|
||||
TEST_F(ReferenceUtilTest, SliceArray2D) {
|
||||
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
|
||||
|
||||
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
|
||||
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
|
||||
|
||||
auto result =
|
||||
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}});
|
||||
auto actual_literal = Literal::CreateR3FromArray3D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR3Near<float>(
|
||||
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
|
||||
@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
|
||||
|
||||
auto result =
|
||||
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}});
|
||||
auto actual_literal = Literal::CreateR3FromArray3D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR3Near<float>(
|
||||
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
|
||||
@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
|
||||
|
||||
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}},
|
||||
{{1, 1, 1, 1}});
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR4Near<float>(
|
||||
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
|
||||
@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
|
||||
|
||||
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}},
|
||||
{{1, 2, 2, 2}});
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*result);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
|
||||
|
||||
LiteralTestUtil::ExpectR4Near<float>(
|
||||
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
|
||||
@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
|
||||
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame);
|
||||
Array3D<float> expected = {{{17, 28, 39, 20}}};
|
||||
|
||||
auto actual_literal = Literal::CreateR3FromArray3D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
|
||||
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid);
|
||||
Array3D<float> expected = {{{17, 28, 39}}};
|
||||
|
||||
auto actual_literal = Literal::CreateR3FromArray3D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
|
||||
}});
|
||||
// clang-format on
|
||||
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
|
||||
Array4D<float> expected({{{{2514, 2685}}}});
|
||||
// clang-format on
|
||||
|
||||
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
|
||||
|
||||
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
|
||||
ErrorSpec(0.0001));
|
||||
@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
|
||||
|
||||
auto actual = ReferenceUtil::ApplyElementwise2D(
|
||||
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
|
||||
auto actual_literal = Literal::CreateR2FromArray2D(*actual);
|
||||
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
|
||||
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
|
||||
*actual_literal, ErrorSpec(0.0001));
|
||||
}
|
||||
|
@ -97,7 +97,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
|
||||
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
|
||||
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
Literal::CreateR1<float>(expected);
|
||||
LiteralUtil::CreateR1<float>(expected);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
|
||||
computation, {}, nullptr));
|
||||
|
@ -136,7 +136,7 @@ cc_library(
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
":tuple_simplifier",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -227,6 +227,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_query",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -244,7 +245,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_evaluator",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -294,6 +295,7 @@ cc_library(
|
||||
":hlo_reachability",
|
||||
":name_uniquer",
|
||||
"//tensorflow/compiler/xla:array",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
@ -396,6 +398,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
@ -407,7 +410,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -424,7 +427,7 @@ tf_cc_test(
|
||||
srcs = ["hlo_sharding_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -453,7 +456,7 @@ tf_cc_test(
|
||||
srcs = ["call_graph_test.cc"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -502,7 +505,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -521,7 +524,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":call_graph",
|
||||
":flatten_call_graph",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -797,7 +800,7 @@ cc_library(
|
||||
hdrs = ["transfer_manager.h"],
|
||||
deps = [
|
||||
":shaped_buffer",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -960,7 +963,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1038,7 +1041,7 @@ tf_cc_test(
|
||||
":hlo_ordering",
|
||||
":hlo_value",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
@ -1121,7 +1124,7 @@ cc_library(
|
||||
hdrs = ["hlo_query.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
],
|
||||
)
|
||||
@ -1170,6 +1173,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -1200,6 +1204,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -1219,6 +1224,7 @@ cc_library(
|
||||
":hlo_creation_utils",
|
||||
":hlo_pass",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
],
|
||||
@ -1233,7 +1239,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1255,6 +1261,7 @@ cc_library(
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -1274,7 +1281,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1310,7 +1317,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1345,7 +1352,7 @@ cc_library(
|
||||
":call_inliner",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1361,6 +1368,7 @@ tf_cc_test(
|
||||
":conditional_simplifier",
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1420,7 +1428,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":defuser",
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
],
|
||||
@ -1448,7 +1456,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":implicit_broadcast_remover",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
],
|
||||
@ -1490,7 +1498,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":tuple_simplifier",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1505,7 +1513,7 @@ cc_library(
|
||||
hdrs = ["reshape_mover.h"],
|
||||
deps = [
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -1520,7 +1528,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":reshape_mover",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1555,7 +1563,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":inliner",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -1572,7 +1580,7 @@ cc_library(
|
||||
hdrs = ["computation_placer.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -1604,7 +1612,7 @@ cc_library(
|
||||
hdrs = ["generic_transfer_manager.h"],
|
||||
deps = [
|
||||
":transfer_manager",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -1695,7 +1703,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1710,6 +1718,7 @@ tf_cc_binary(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -1724,7 +1733,7 @@ tf_cc_test(
|
||||
srcs = ["hlo_module_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -1822,7 +1831,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":hlo_ordering",
|
||||
":instruction_fusion",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1859,7 +1868,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_liveness_analysis",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1920,7 +1929,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":hlo_ordering",
|
||||
":instruction_fusion",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1955,6 +1964,7 @@ cc_library(
|
||||
":hlo_dataflow_analysis",
|
||||
":logical_buffer",
|
||||
":logical_buffer_analysis",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -1973,6 +1983,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":instruction_fusion",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -2044,7 +2055,7 @@ tf_cc_test(
|
||||
":hlo_graph_dumper",
|
||||
":hlo_matchers",
|
||||
":hlo_runner",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -2169,6 +2180,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_dce",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -2189,7 +2201,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_module_dce",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -2213,7 +2225,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":layout_assignment",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -2272,7 +2284,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_domain_map",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -2288,7 +2300,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_cse",
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -2310,7 +2322,7 @@ cc_library(
|
||||
":hlo_evaluator",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2325,7 +2337,7 @@ tf_cc_test(
|
||||
":hlo_constant_folding",
|
||||
":hlo_matchers",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -2417,7 +2429,7 @@ cc_library(
|
||||
":hlo_evaluator",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2552,7 +2564,7 @@ cc_library(
|
||||
hdrs = ["hlo_tfgraph_builder.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/core:framework",
|
||||
@ -2583,7 +2595,7 @@ cc_library(
|
||||
":hlo_casting_utils",
|
||||
":hlo_execution_profile",
|
||||
":hlo_tfgraph_builder",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
@ -2601,6 +2613,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
@ -2632,7 +2645,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":shape_inference",
|
||||
":transpose_folding",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -2653,7 +2666,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -2668,7 +2681,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":shape_inference",
|
||||
":zero_sized_hlo_elimination",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -2828,6 +2841,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_creation_utils",
|
||||
":tuple_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
@ -2963,6 +2977,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_lexer",
|
||||
":hlo_sharding_metadata",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -195,7 +196,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
|
||||
HloInstruction* zero =
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::Zero(hlo->shape().element_type()).CloneToUnique()));
|
||||
LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
|
||||
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
|
||||
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
|
||||
return computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
@ -537,8 +538,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
|
||||
// If a literal is all the same element replace it with a scalar broadcast.
|
||||
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
|
||||
constant->literal().IsAllFirst()) {
|
||||
std::unique_ptr<Literal> unique_scalar =
|
||||
MakeUnique<Literal>(constant->literal().GetFirstScalarLiteral());
|
||||
std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
|
||||
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
|
||||
HloInstruction* scalar = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(unique_scalar)));
|
||||
return ReplaceWithNewInstruction(
|
||||
@ -1093,7 +1094,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
ShapeUtil::IsZeroElementArray(lhs->shape()) ||
|
||||
ShapeUtil::IsZeroElementArray(rhs->shape())) {
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
|
||||
return ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
|
||||
}
|
||||
@ -1519,7 +1520,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
|
||||
if (IsAll(rhs, 0)) {
|
||||
auto one = HloInstruction::CreateConstant(
|
||||
Literal::One(power->shape().element_type()).CloneToUnique());
|
||||
LiteralUtil::One(power->shape().element_type()).CloneToUnique());
|
||||
std::unique_ptr<HloInstruction> ones;
|
||||
if (ShapeUtil::IsScalar(power->shape())) {
|
||||
ones = std::move(one);
|
||||
@ -1554,7 +1555,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||
if (IsAll(rhs, -1)) {
|
||||
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::One(rhs->shape().element_type()).CloneToUnique()));
|
||||
LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
|
||||
|
||||
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
|
||||
// broadcast in divide HLO as we are trying to eliminate implicit
|
||||
@ -2098,7 +2099,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
HloInstruction::CreateBroadcast(
|
||||
convolution->shape(),
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::Zero(convolution->shape().element_type())
|
||||
LiteralUtil::Zero(convolution->shape().element_type())
|
||||
.CloneToUnique())),
|
||||
{}));
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
|
||||
|
||||
@ -79,7 +79,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
// Create add computation.
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloComputation* add_computation = nullptr;
|
||||
{
|
||||
HloComputation::Builder builder(TestName() + ".add");
|
||||
@ -119,7 +119,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
|
||||
|
||||
@ -140,9 +140,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
|
||||
HloInstruction* constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
|
||||
|
||||
HloInstruction* add1 = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
|
||||
@ -165,7 +165,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction* bcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
|
||||
builder.AddInstruction(
|
||||
@ -200,7 +200,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateMap(
|
||||
r2f32,
|
||||
{param0, builder.AddInstruction(
|
||||
@ -223,7 +223,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
|
||||
HloInstruction* bcast =
|
||||
builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
|
||||
builder.AddInstruction(
|
||||
@ -242,7 +242,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
|
||||
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({3.14f, 3.14f, 3.14f})));
|
||||
LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));
|
||||
|
||||
auto computation = module().AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
@ -258,7 +258,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
|
||||
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({3.14, 3.14, 4})));
|
||||
LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
|
||||
|
||||
auto computation = module().AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
@ -277,7 +277,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
|
||||
|
||||
@ -298,7 +298,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
r0f32, HloOpcode::kSubtract, param0, constant));
|
||||
|
||||
@ -493,7 +493,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
|
||||
HloInstruction::CreateParameter(0, r1f32, "param0"));
|
||||
HloInstruction* constant =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({0.f, 1.f, 2.f})));
|
||||
LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
|
||||
param0, constant));
|
||||
|
||||
@ -559,7 +559,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
|
||||
HloInstruction* div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
|
||||
|
||||
@ -580,7 +580,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
|
||||
LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
|
||||
HloInstruction* div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
|
||||
|
||||
@ -860,7 +860,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
|
||||
|
||||
@ -884,7 +884,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r1f32, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
|
||||
|
||||
@ -912,7 +912,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
|
||||
|
||||
@ -934,7 +934,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* two = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
|
||||
|
||||
@ -956,7 +956,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* negative_one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
|
||||
param0, negative_one));
|
||||
|
||||
@ -1047,7 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
|
||||
builder.AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
ShapeUtil::MakeShape(F32, {5, 2}), param,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
|
||||
window, add_computation));
|
||||
module().AddEntryComputation(builder.Build());
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
|
||||
@ -1074,7 +1074,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
|
||||
builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(F32, {5, 2}), param,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
|
||||
padding));
|
||||
module().AddEntryComputation(builder.Build());
|
||||
EXPECT_THAT(module().entry_computation()->root_instruction(),
|
||||
@ -1116,7 +1116,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
|
||||
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* input = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
|
||||
|
||||
@ -1208,7 +1208,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, r1f32, "param1"));
|
||||
HloInstruction* empty_literal = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
|
||||
HloInstruction* empty_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
|
||||
@ -1238,7 +1238,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r1f32, "param0"));
|
||||
HloInstruction* empty_literal = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
|
||||
HloInstruction* empty_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
|
||||
@ -1420,7 +1420,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param0")),
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
|
||||
LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
|
||||
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
|
||||
@ -1443,7 +1443,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param0")),
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
|
||||
LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
|
||||
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
|
||||
@ -1726,7 +1726,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
PaddingConfig no_padding;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
auto dimension = no_padding.add_dimensions();
|
||||
@ -1757,7 +1757,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
PaddingConfig padding;
|
||||
int64 low_padding[2] = {-1, -2};
|
||||
int64 high_padding[2] = {2, -3};
|
||||
@ -2109,7 +2109,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
|
||||
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* forty_two = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
|
||||
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
|
||||
HloInstruction* broadcast = builder.AddInstruction(
|
||||
@ -2156,7 +2156,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
padding.mutable_dimensions(3)->set_edge_padding_high(2);
|
||||
|
||||
HloInstruction* pad_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
|
||||
|
||||
@ -2187,7 +2187,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
const Shape reduce_window_shape =
|
||||
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
|
||||
HloInstruction* reduce_init_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction* reduce_window =
|
||||
builder.AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
reduce_window_shape, pad, reduce_init_value, window,
|
||||
@ -2238,7 +2238,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
|
||||
padding.mutable_dimensions(3)->set_edge_padding_high(2);
|
||||
|
||||
HloInstruction* pad_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
|
||||
|
||||
@ -2273,7 +2273,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
|
||||
const Shape reduce_window_shape =
|
||||
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
|
||||
HloInstruction* reduce_init_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction* reduce_window =
|
||||
builder.AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
reduce_window_shape, convert, reduce_init_value, window,
|
||||
@ -2344,9 +2344,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
|
||||
|
||||
HloComputation::Builder call_builder(TestName() + ".Call");
|
||||
HloInstruction* zero = call_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
|
||||
HloInstruction* one = call_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
|
||||
call_builder.AddInstruction(
|
||||
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
|
||||
|
||||
@ -2362,9 +2362,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const float constant_scalar = 7.3f;
|
||||
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
|
||||
std::unique_ptr<Literal> value =
|
||||
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
|
||||
Literal::CreateR1<float>(constant_vector).get()});
|
||||
std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
|
||||
{LiteralUtil::CreateR0<float>(constant_scalar).get(),
|
||||
LiteralUtil::CreateR1<float>(constant_vector).get()});
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
|
||||
|
||||
auto computation = module().AddEntryComputation(builder.Build());
|
||||
@ -2387,8 +2387,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
|
||||
shape,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "slice_from")),
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int>({0, 0, 0}))),
|
||||
/*slice_sizes=*/{10, 100, 1000}));
|
||||
|
||||
auto computation = module().AddEntryComputation(builder.Build());
|
||||
@ -2421,8 +2421,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(2, slice_shape, "to_update")),
|
||||
slice,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int>({0, 0, 0})))));
|
||||
|
||||
auto computation = module().AddEntryComputation(builder.Build());
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
@ -2437,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
HloInstruction* input_array = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
|
||||
HloInstruction* inner_bcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
|
||||
Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
|
||||
@ -2546,7 +2546,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
|
||||
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
pad_shape, input,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
|
||||
padding));
|
||||
|
||||
HloComputation* add_computation = nullptr;
|
||||
@ -2565,7 +2565,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
|
||||
Window window = window_util::MakeWindow(
|
||||
decorate_spatials(param.reduce_window_spatials, 1, 1));
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
|
||||
ShapeInference::InferReduceWindowShape(
|
||||
pad->shape(), zero->shape(), window,
|
||||
@ -2704,7 +2704,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
|
||||
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
|
||||
auto* lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
|
||||
|
||||
Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
|
||||
@ -2783,7 +2783,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
|
||||
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
|
||||
auto* rhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
|
||||
|
||||
DotDimensionNumbers dot_dnums;
|
||||
@ -2830,7 +2830,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
|
||||
HloInstruction* const update = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, update_shape, "update"));
|
||||
HloInstruction* const start_indices = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int>({0})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int>({0})));
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
dslice_shape, operand, update, start_indices));
|
||||
const HloComputation* const computation =
|
||||
@ -2879,7 +2879,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
|
||||
int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
|
||||
auto* lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
|
||||
/*cols=*/lhs_cols)));
|
||||
|
||||
@ -2887,7 +2887,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
|
||||
int32 start_col = (spec.lcd == 0) ? spec.s : 0;
|
||||
const auto start_indices =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<int32>({start_row, start_col})));
|
||||
LiteralUtil::CreateR1<int32>({start_row, start_col})));
|
||||
int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
|
||||
int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
|
||||
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
|
||||
@ -2898,7 +2898,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
|
||||
int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
|
||||
auto* rhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
|
||||
/*cols=*/rhs_cols)));
|
||||
|
||||
@ -2946,7 +2946,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
|
||||
int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
|
||||
auto* lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
|
||||
/*cols=*/lhs_cols)));
|
||||
|
||||
@ -2957,7 +2957,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
|
||||
int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
|
||||
auto* rhs = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
|
||||
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
|
||||
/*cols=*/rhs_cols)));
|
||||
|
||||
@ -2965,7 +2965,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
|
||||
int32 start_col = (spec.rcd == 0) ? spec.s : 0;
|
||||
const auto start_indices =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<int32>({start_row, start_col})));
|
||||
LiteralUtil::CreateR1<int32>({start_row, start_col})));
|
||||
int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
|
||||
int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
|
||||
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -97,7 +98,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
add_instruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
|
||||
add_instruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR0<float>(-0.5f))))),
|
||||
LiteralUtil::CreateR0<float>(-0.5f))))),
|
||||
{}));
|
||||
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
|
||||
operand, exponent);
|
||||
@ -113,7 +114,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
add_instruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
|
||||
add_instruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR0<float>(1.0 / element_count))))),
|
||||
LiteralUtil::CreateR0<float>(1.0 / element_count))))),
|
||||
{}));
|
||||
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
|
||||
operand, elem_count_recip);
|
||||
@ -200,11 +201,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
|
||||
HloInstruction* offset = batch_norm->mutable_operand(2);
|
||||
const Shape feature_shape = scale->shape();
|
||||
|
||||
auto zero_literal = Literal::CreateR0(0.0f);
|
||||
auto zero_literal = LiteralUtil::CreateR0(0.0f);
|
||||
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
|
||||
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon = add(HloInstruction::CreateBroadcast(
|
||||
operand_shape,
|
||||
@ -320,7 +321,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
|
||||
HloInstruction* var = batch_norm->mutable_operand(4);
|
||||
const Shape feature_shape = scale->shape();
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
operand_shape,
|
||||
@ -447,11 +448,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
|
||||
const int64 feature_count = activation_shape.dimensions(feature_index);
|
||||
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
|
||||
|
||||
auto zero_literal = Literal::CreateR0(0.0f);
|
||||
auto zero_literal = LiteralUtil::CreateR0(0.0f);
|
||||
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
|
||||
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
|
||||
|
||||
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
|
||||
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
|
||||
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
|
||||
auto epsilon_scalar =
|
||||
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
|
||||
@ -542,7 +543,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
|
||||
Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
|
||||
|
||||
auto elements_per_feature_literal =
|
||||
Literal::CreateR0<float>(elements_per_feature_int64);
|
||||
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
|
||||
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
|
||||
elements_per_feature_literal->Convert(ptype));
|
||||
auto elements_per_feature = add(
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
|
@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
|
||||
array_b.FillUnique(10.0f);
|
||||
|
||||
HloInstruction* a = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateFromArray(array_a)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateFromArray(array_b)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
|
||||
|
||||
@ -150,10 +150,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
|
||||
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
|
||||
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
*Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)),
|
||||
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
|
||||
dot->operand(0)->literal()));
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
*Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)),
|
||||
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
|
||||
dot->operand(1)->literal()));
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
@ -125,7 +125,7 @@ class BufferAssignmentTest : public HloTestBase {
|
||||
auto param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
|
||||
auto value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
|
||||
return builder.Build();
|
||||
@ -142,7 +142,7 @@ class BufferAssignmentTest : public HloTestBase {
|
||||
const string& name) {
|
||||
auto builder = HloComputation::Builder(name);
|
||||
auto const4 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
|
||||
auto index = builder.AddInstruction(
|
||||
@ -167,9 +167,9 @@ class BufferAssignmentTest : public HloTestBase {
|
||||
const string& name) {
|
||||
auto builder = HloComputation::Builder(name);
|
||||
auto const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
|
||||
auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
|
||||
auto indexc = builder.AddInstruction(
|
||||
@ -290,7 +290,7 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
|
||||
TEST_F(BufferAssignmentTest, ScalarConstant) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
@ -304,9 +304,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
|
||||
// no buffers assigned, and their consumer has a buffer.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
|
||||
LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
|
||||
auto add = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
|
||||
auto module = CreateNewModule();
|
||||
@ -327,7 +327,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) {
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, f32vec100_, "param0"));
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
|
||||
auto negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
auto tuple = builder.AddInstruction(
|
||||
@ -352,7 +352,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
|
||||
// This computation copies a constant to output.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
auto copy = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
|
||||
auto module = CreateNewModule();
|
||||
@ -660,7 +660,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
|
||||
auto exp2 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
|
||||
/*shape=*/f32vec10_,
|
||||
/*operand=*/exp2,
|
||||
@ -708,9 +708,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
|
||||
// Creates the main kernel and verifies instruction counts.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const3 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
|
||||
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
|
||||
auto tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
|
||||
auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
@ -773,11 +773,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) {
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
auto const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
|
||||
auto const2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
|
||||
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
|
||||
r0f32_, pred, const1, true_computation, const2, false_computation));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
@ -1200,8 +1200,9 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
|
||||
// Test that a tuple constant which is forwarded to the computation output
|
||||
// is properly handled.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
|
||||
{Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
|
||||
LiteralUtil::CreateR0<int64>(1).get()})));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
@ -1584,7 +1585,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
|
||||
auto b = HloComputation::Builder(TestName() + ".cond");
|
||||
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
b.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
|
||||
condition = module->AddEmbeddedComputation(b.Build());
|
||||
}
|
||||
HloComputation* body;
|
||||
@ -1647,9 +1648,9 @@ class WhileBufferAssignmentTest : public HloTestBase {
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
|
||||
auto ten = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
|
||||
return builder.Build();
|
||||
@ -1708,7 +1709,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
|
||||
HloInstruction::CreateParameter(2, data_shape_, "weights1"));
|
||||
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
|
||||
auto output0 = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
|
||||
auto output1 = builder.AddInstruction(
|
||||
@ -1851,7 +1852,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
||||
auto build_cond = [&]() {
|
||||
auto builder = HloComputation::Builder("cond");
|
||||
auto const4 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
||||
auto param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
@ -1863,7 +1864,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
||||
auto build_body = [&]() {
|
||||
auto builder = HloComputation::Builder("body");
|
||||
auto const9 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(9)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
|
||||
auto param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
|
||||
builder.AddInstruction(
|
||||
@ -1891,7 +1892,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
||||
HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
|
||||
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
|
||||
auto add = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
|
||||
auto cond2 = module->AddEmbeddedComputation(build_cond());
|
||||
@ -1953,7 +1954,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
|
||||
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
|
||||
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
|
||||
auto output0 = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
|
||||
|
||||
@ -1997,16 +1998,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param"));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto add = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
|
||||
sub_computation = module->AddEmbeddedComputation(builder.Build(add));
|
||||
}
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
auto constant3 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
|
||||
auto call1 = builder.AddInstruction(
|
||||
HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
|
||||
auto call2 = builder.AddInstruction(
|
||||
@ -2058,9 +2059,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
|
||||
auto input0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, data_shape_, "input0"));
|
||||
@ -2142,7 +2143,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
|
||||
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
|
||||
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
|
||||
auto output0 = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
|
||||
auto output1 = builder.AddInstruction(
|
||||
|
@ -439,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
|
||||
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
|
||||
// the buffers containing {3} and 3 are dead.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto inner_tuple0 = Literal::MakeTuple(
|
||||
{Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
|
||||
auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
|
||||
auto inner_tuple0 =
|
||||
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
|
||||
LiteralUtil::CreateR0<int64>(1).get()});
|
||||
auto inner_tuple1 =
|
||||
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
|
||||
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
|
||||
LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
inner_tuple0->shape(), tuple_constant, 0));
|
||||
|
||||
@ -491,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
tuple_element0_shape, tuple_param0, 0));
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
|
||||
|
||||
@ -503,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
tuple_element1_shape, tuple_param0, 1));
|
||||
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
|
||||
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
|
||||
|
||||
@ -555,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
tuple_element0_shape, tuple_param0, 0));
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
|
||||
|
||||
@ -627,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
|
||||
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
HloInstruction* slice = nullptr;
|
||||
if (update_uses_tuple_element1) {
|
||||
// Create a slice instruction as an additional user of 'gte1'.
|
||||
@ -638,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
}
|
||||
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
|
||||
auto starts = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
|
||||
auto dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
data_shape, gte1, update, starts));
|
||||
@ -757,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
|
||||
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
|
||||
if (tuple_element1_has_two_uses) {
|
||||
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
|
||||
@ -766,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
}
|
||||
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
|
||||
auto starts = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
|
||||
auto dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
data_shape, gte1, update, starts));
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
||||
return builder.Build();
|
||||
@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
|
||||
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
HloInstruction* const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
|
||||
HloInstruction* const2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(12.6f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
|
||||
HloInstruction* conditional =
|
||||
builder.AddInstruction(HloInstruction::CreateConditional(
|
||||
kScalarShape, pred, const1, true_computation, const2,
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||
// the "one" value.
|
||||
HloComputation::Builder inner(TestName() + ".inner");
|
||||
HloInstruction* zero = inner.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f)));
|
||||
HloInstruction* one = inner.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
TF_ASSERT_OK(zero->AddControlDependencyTo(one));
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* inner_computation =
|
||||
@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
|
||||
// little trickier.
|
||||
HloComputation::Builder just_false(TestName() + ".false");
|
||||
just_false.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
HloComputation* false_computation =
|
||||
module->AddEmbeddedComputation(just_false.Build());
|
||||
|
||||
@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
|
||||
|
||||
HloComputation::Builder outer(TestName() + ".outer");
|
||||
HloInstruction* init_value = outer.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
outer.AddInstruction(
|
||||
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
|
||||
|
||||
@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
|
||||
|
||||
HloComputation::Builder just_false(TestName() + ".false");
|
||||
auto* true_constant = just_false.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<bool>({true})));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true})));
|
||||
auto* false_constant = just_false.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant));
|
||||
HloComputation* false_computation =
|
||||
module->AddEmbeddedComputation(just_false.Build());
|
||||
@ -147,7 +147,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
|
||||
|
||||
HloComputation::Builder outfeeder(TestName() + ".outfeeder");
|
||||
auto value = outfeeder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
||||
auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({}));
|
||||
outfeeder.AddInstruction(
|
||||
HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/""));
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
||||
true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(S32, {}), "param"));
|
||||
auto one = true_computation_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
|
||||
true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
|
||||
@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
|
||||
"param"));
|
||||
auto forty_two = false_computation_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
|
||||
|
||||
false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
|
||||
@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
||||
}
|
||||
|
||||
auto false_instrn = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
|
||||
builder.AddInstruction(HloInstruction::CreateConditional(
|
||||
ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
|
||||
@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
|
||||
HloComputation* computation = MakeConditional(&module());
|
||||
|
||||
auto* true_op = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
|
||||
TF_ASSERT_OK(
|
||||
true_op->AddControlDependencyTo(computation->root_instruction()));
|
||||
|
||||
@ -123,7 +123,7 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
|
||||
true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
|
||||
auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
|
||||
true_computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
|
||||
token, /*channel_id=*/0));
|
||||
true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
|
||||
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) {
|
||||
// be copied before entering the tuple.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
HloInstruction* tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
|
||||
|
||||
@ -132,7 +132,7 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* constant =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
|
||||
LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
|
||||
auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape());
|
||||
Layout reversed_layout =
|
||||
LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major);
|
||||
@ -167,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
HloInstruction* constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
HloInstruction* constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
|
||||
HloInstruction* x = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
|
||||
@ -197,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
|
||||
// the computation result. Verify that copies are added properly.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
HloInstruction* constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
HloInstruction* constant3 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
|
||||
|
||||
HloInstruction* tuple1 = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant1, constant2}));
|
||||
@ -209,7 +209,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
|
||||
HloInstruction::CreateTuple({constant3, constant2}));
|
||||
|
||||
HloInstruction* pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
|
||||
|
||||
@ -255,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
|
||||
// The output of a bitcast is its operand (same buffer), so a bitcast
|
||||
// constant feeding the result must have a copy added.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0})));
|
||||
HloInstruction* constant =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({1.0, 42.0})));
|
||||
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
|
||||
|
||||
@ -370,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
|
||||
// copy is added.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
HloInstruction* constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
|
||||
HloInstruction* tuple1 = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant1, constant2}));
|
||||
@ -380,7 +381,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
|
||||
HloInstruction::CreateTuple({constant2, constant1}));
|
||||
|
||||
HloInstruction* pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
|
||||
HloInstruction* gte =
|
||||
@ -413,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
const Shape& loop_state_shape) {
|
||||
auto builder = HloComputation::Builder(TestName() + ".Condition");
|
||||
auto limit_const = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
|
||||
auto loop_state = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
|
||||
auto induction_variable =
|
||||
@ -442,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
induction_variable_shape_, loop_state, 0));
|
||||
auto inc = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
|
||||
// Update data GTE(1).
|
||||
@ -480,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
induction_variable_shape_, loop_state, 0));
|
||||
auto inc = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
|
||||
// add0 = Add(in0, 1)
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
@ -549,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
induction_variable_shape_, loop_state, 0));
|
||||
auto inc = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
// add0 = Add(in0, 1)
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
|
||||
@ -564,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
data = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
|
||||
}
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto update = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
// add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
|
||||
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
data_shape_, HloOpcode::kAdd, data, update));
|
||||
@ -598,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
induction_variable_shape_, loop_state, 0));
|
||||
auto inc = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
gte0->shape(), HloOpcode::kAdd, gte0, inc));
|
||||
|
||||
@ -608,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
// GTE(GTE(loop_state, 1), 0) -> Add
|
||||
auto gte10 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
|
||||
auto update10 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto update10 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
data_shape_, HloOpcode::kAdd, gte10, update10));
|
||||
|
||||
@ -633,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
bool nested = false) {
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
auto induction_var_init = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
|
||||
|
||||
auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
|
||||
auto data_init = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
|
||||
|
||||
if (nested) {
|
||||
auto inner_init = builder.AddInstruction(
|
||||
@ -659,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
|
||||
HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
|
||||
auto data_init = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
|
||||
return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
|
||||
&builder);
|
||||
}
|
||||
@ -677,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto v1 = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
|
||||
auto zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto v2 = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
|
||||
|
||||
@ -689,7 +694,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
|
||||
|
||||
auto pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
|
||||
|
||||
@ -701,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto one_vec = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
|
||||
auto data_init =
|
||||
@ -714,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto data_init = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
|
||||
auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
auto one_vec = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
|
||||
// Take a reference to 'data_init' to make it interfere with while result.
|
||||
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
data_shape_, HloOpcode::kAdd, data_init, one_vec));
|
||||
@ -750,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
const bool nested =
|
||||
ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
|
||||
auto induction_var_init = builder->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
|
||||
auto condition = module_->AddEmbeddedComputation(
|
||||
BuildConditionComputation(loop_state_shape));
|
||||
auto body = module_->AddEmbeddedComputation(
|
||||
@ -1252,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
|
||||
auto loop_init = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
|
||||
|
||||
|
||||
// Two while loops shares the same loop init tuple.
|
||||
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
loop_state_shape, condition1, body1, loop_init));
|
||||
@ -1310,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
|
||||
auto cond_constant = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
cond_constant->shape(), HloOpcode::kNot, cond_constant));
|
||||
HloComputation* condition =
|
||||
@ -1318,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
auto tuple = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant1, constant2}));
|
||||
auto xla_while = builder.AddInstruction(
|
||||
@ -1375,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
|
||||
auto cond_constant = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
cond_constant->shape(), HloOpcode::kNot, cond_constant));
|
||||
HloComputation* condition =
|
||||
@ -1383,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
auto tuple = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant1, constant2}));
|
||||
auto xla_while = builder.AddInstruction(
|
||||
@ -1435,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
|
||||
auto cond_constant = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
cond_constant->shape(), HloOpcode::kNot, cond_constant));
|
||||
HloComputation* condition =
|
||||
@ -1443,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
|
||||
builder.AddInstruction(
|
||||
@ -1520,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) {
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
|
||||
auto cond_constant = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
cond_constant->shape(), HloOpcode::kNot, cond_constant));
|
||||
HloComputation* condition =
|
||||
@ -1575,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
|
||||
body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
|
||||
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto cond_builder = HloComputation::Builder("condition");
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
HloComputation* condition =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
@ -1644,7 +1649,7 @@ std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "loop_state"));
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
constant->shape(), HloOpcode::kNot, constant));
|
||||
return builder.Build();
|
||||
|
@ -37,6 +37,7 @@ cc_library(
|
||||
srcs = ["cpu_transfer_manager.cc"],
|
||||
hdrs = ["cpu_transfer_manager.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -72,7 +73,7 @@ cc_library(
|
||||
":ir_emitter",
|
||||
":parallel_task_assignment",
|
||||
":simple_orc_jit",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -355,7 +356,7 @@ tf_cc_binary(
|
||||
srcs = ["sample_harness.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -717,7 +718,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":cpu_layout_assignment",
|
||||
":target_machine_features_fake",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -809,7 +810,7 @@ tf_cc_test(
|
||||
":cpu_executable",
|
||||
":parallel_task_assignment",
|
||||
":target_machine_features_fake",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -892,7 +893,7 @@ tf_cc_test(
|
||||
srcs = ["cpu_copy_insertion_test.cc"],
|
||||
deps = [
|
||||
":cpu_copy_insertion",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
|
@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
// The input dimensions are in CNHW order.
|
||||
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR4FromArray4D(Array4D<float>(
|
||||
LiteralUtil::CreateR4FromArray4D(Array4D<float>(
|
||||
kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
|
||||
// The kernel dimensions are in OIHW order.
|
||||
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR4FromArray4D(Array4D<float>(
|
||||
LiteralUtil::CreateR4FromArray4D(Array4D<float>(
|
||||
kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
|
||||
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
// The input dimensions are in NHWC order.
|
||||
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR4FromArray4D(Array4D<float>(
|
||||
LiteralUtil::CreateR4FromArray4D(Array4D<float>(
|
||||
kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
|
||||
// The kernel dimensions are in HWIO order.
|
||||
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR4FromArray4D(Array4D<float>(
|
||||
LiteralUtil::CreateR4FromArray4D(Array4D<float>(
|
||||
kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount))));
|
||||
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
|
||||
body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
|
||||
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto cond_builder = HloComputation::Builder("condition");
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
HloComputation* condition =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
|
||||
auto sub_param = sub_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
auto constant = sub_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
|
||||
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
|
||||
sub_builder.AddInstruction(
|
||||
|
@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest {
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {}), "arg0"));
|
||||
HloInstruction* one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
|
||||
return module->AddEmbeddedComputation(builder.Build());
|
||||
@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
|
||||
auto pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(S32, {5}), idx_choice,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
|
||||
padding_config));
|
||||
|
||||
auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -180,7 +181,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions(
|
||||
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
|
||||
literal_shape.dimensions().size());
|
||||
*literal = std::move(*Literal::CreateFromDimensions(
|
||||
*literal = std::move(*LiteralUtil::CreateFromDimensions(
|
||||
literal_shape.element_type(), dimensions));
|
||||
TF_ASSIGN_OR_RETURN(Shape received_shape,
|
||||
TransferArrayBufferFromOutfeed(
|
||||
@ -211,7 +212,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
tensorflow::bit_cast<const int64*>(
|
||||
tuple_element_shape.dimensions().data()),
|
||||
tuple_element_shape.dimensions().size());
|
||||
auto empty = Literal::CreateFromDimensions(
|
||||
auto empty = LiteralUtil::CreateFromDimensions(
|
||||
tuple_element_shape.element_type(), dimensions);
|
||||
int64 size = GetByteSizeRequirement(tuple_element_shape);
|
||||
buffer_data.push_back({empty->untyped_data(), size});
|
||||
@ -232,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
|
||||
*elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
|
||||
}
|
||||
*literal = std::move(*Literal::MakeTupleOwned(std::move(elements)));
|
||||
*literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -38,12 +38,13 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Transfer parameters.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
|
||||
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> param1_literal = xla::Literal::CreateR2<float>(
|
||||
{{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::LiteralUtil::CreateR2<float>(
|
||||
{{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
|
@ -40,7 +40,7 @@ tf_cc_test(
|
||||
name = "cpu_fusion_test",
|
||||
srcs = ["cpu_fusion_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -82,7 +82,7 @@ tf_cc_test(
|
||||
name = "cpu_noalias_test",
|
||||
srcs = ["cpu_noalias_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -128,7 +128,7 @@ tf_cc_test(
|
||||
name = "cpu_infeed_test",
|
||||
srcs = ["cpu_infeed_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
|
@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest {
|
||||
|
||||
HloInstruction* constant =
|
||||
builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2FromArray2D(backing_array)));
|
||||
LiteralUtil::CreateR2FromArray2D(backing_array)));
|
||||
HloInstruction* param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
|
||||
builder.AddInstruction(
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase {
|
||||
|
||||
TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
|
||||
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
|
||||
Shape vshape = input_literal1->shape();
|
||||
|
||||
auto input1 = builder.AddInstruction(
|
||||
@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
|
||||
|
||||
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
|
||||
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
|
||||
Shape vshape = input_literal->shape();
|
||||
|
||||
auto input = builder.AddInstruction(
|
||||
@ -99,7 +99,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
|
||||
auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
vshape,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
|
||||
{}));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor));
|
||||
@ -134,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
|
||||
// middle.
|
||||
auto module = CreateNewModule();
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
|
||||
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
|
||||
Shape vshape = input_literal->shape();
|
||||
|
||||
auto input = builder.AddInstruction(
|
||||
@ -166,7 +166,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
|
||||
ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
|
||||
/*init_value=*/
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
|
||||
/*dimensions_to_reduce=*/{1}, add_f32));
|
||||
|
||||
auto exp = builder.AddInstruction(
|
||||
@ -176,7 +176,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
|
||||
auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
cshape,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
|
||||
{}));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
|
||||
@ -231,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
|
||||
// operand vectors. Test for this problem by counting the number of nodes in
|
||||
// each fusion instruction to ensure that negate is not duplicated.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
Shape vshape = input_literal->shape();
|
||||
|
||||
auto constant = builder.AddInstruction(
|
||||
@ -292,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
|
||||
// computation. The duplication is caused by the other use of exp2 in the
|
||||
// tuple.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
|
||||
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
||||
Shape shape = constant->shape();
|
||||
|
||||
auto exp1 = builder.AddInstruction(
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
|
||||
};
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR0Bool) {
|
||||
TestInfeedRoundTrip(*Literal::CreateR0<bool>(true));
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR1U32) {
|
||||
TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3}));
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR2F32) {
|
||||
TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64));
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR3F32) {
|
||||
TestInfeedRoundTrip(
|
||||
*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
|
||||
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
|
||||
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
|
||||
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
|
||||
|
||||
TestInfeedRoundTrip(
|
||||
*Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
|
||||
r3_dim0minor));
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
|
||||
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
|
||||
r3_dim0minor));
|
||||
|
||||
TestInfeedRoundTrip(
|
||||
*Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
|
||||
r3_dim0major));
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
|
||||
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
|
||||
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
|
||||
r3_dim0major));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedR4S32) {
|
||||
TestInfeedRoundTrip(*Literal::CreateR4(
|
||||
TestInfeedRoundTrip(*LiteralUtil::CreateR4(
|
||||
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
|
||||
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedTuple) {
|
||||
TestInfeedRoundTrip(
|
||||
*Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(),
|
||||
Literal::CreateR0<bool>(false).get()}));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
|
||||
LiteralUtil::CreateR0<bool>(false).get()}));
|
||||
}
|
||||
|
||||
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
|
||||
TestInfeedRoundTrip(*Literal::MakeTuple({}));
|
||||
TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
|
||||
}
|
||||
|
||||
// Tests Infeed operation used in a while loop, as in the code below. The
|
||||
@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
|
||||
});
|
||||
|
||||
// Send 5 Infeed data of shape F32[3].
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9})));
|
||||
ASSERT_IS_OK(
|
||||
client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12})));
|
||||
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
|
||||
ASSERT_IS_OK(
|
||||
client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15})));
|
||||
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
|
||||
ASSERT_IS_OK(
|
||||
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
|
||||
ASSERT_IS_OK(
|
||||
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
|
||||
ASSERT_IS_OK(
|
||||
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
|
||||
|
||||
delete computation_thread; // Joins the thread.
|
||||
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
|
||||
@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
|
||||
|
||||
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
|
||||
Literal::CreateR0<bool>(true).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
|
||||
LiteralUtil::CreateR0<bool>(true).get()})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(),
|
||||
Literal::CreateR0<bool>(true).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
|
||||
LiteralUtil::CreateR0<bool>(true).get()})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(),
|
||||
Literal::CreateR0<bool>(true).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
|
||||
LiteralUtil::CreateR0<bool>(true).get()})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(),
|
||||
Literal::CreateR0<bool>(false).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
|
||||
LiteralUtil::CreateR0<bool>(false).get()})));
|
||||
|
||||
// Asynchronously launch the execution on the device.
|
||||
std::unique_ptr<GlobalData> result;
|
||||
@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
|
||||
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
|
||||
sleep(1);
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(),
|
||||
Literal::CreateR0<bool>(true).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
|
||||
LiteralUtil::CreateR0<bool>(true).get()})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(),
|
||||
Literal::CreateR0<bool>(false).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
|
||||
LiteralUtil::CreateR0<bool>(false).get()})));
|
||||
ASSERT_IS_OK(client_->TransferToInfeed(
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(),
|
||||
Literal::CreateR0<bool>(true).get()})));
|
||||
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
|
||||
LiteralUtil::CreateR0<bool>(true).get()})));
|
||||
|
||||
// Wait for the execution to be done, and transfer the result.
|
||||
delete computation_thread; // Joins the thread.
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
|
||||
@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
std::unique_ptr<Literal> literal =
|
||||
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
HloInstruction* param_x = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, param_shape, "x"));
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/defuser.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) {
|
||||
auto div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
|
||||
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
|
||||
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
|
||||
auto add2 = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
|
||||
|
||||
@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) {
|
||||
auto div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
|
||||
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
|
||||
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
|
||||
auto add2 = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
@ -57,8 +57,8 @@ ENTRY main {
|
||||
}
|
||||
)";
|
||||
|
||||
std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});
|
||||
std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});
|
||||
std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
|
||||
std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
|
||||
RunTest(hlo_text, {lhs.get(), rhs.get()});
|
||||
}
|
||||
} // namespace
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
||||
return builder.Build();
|
||||
@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
|
||||
HloInstruction* false_constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
||||
HloOpcode::kEq, param0, false_constant));
|
||||
@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
|
||||
{
|
||||
HloComputation::Builder builder(TestName() + ".entry");
|
||||
HloInstruction* false_constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
|
||||
false_constant));
|
||||
@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
|
||||
// computation in the true and false branch.
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateConditional(
|
||||
kScalarShape, pred, constant1, sub_computation, constant2,
|
||||
sub_computation));
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gather_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -113,7 +114,7 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
|
||||
const Shape& index_shape = index_vector->shape();
|
||||
HloInstruction* zero =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateFromDimensions(index_shape.element_type(), {1})));
|
||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
|
||||
|
||||
// We extract out individual components from the smaller index and concatenate
|
||||
// them (interspersing zeros as needed) into the larger index.
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
@ -150,7 +150,7 @@ cc_library(
|
||||
":parallel_loop_emitter",
|
||||
":partition_assignment",
|
||||
":while_transformer",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -199,7 +199,7 @@ cc_library(
|
||||
srcs = ["elemental_ir_emitter.cc"],
|
||||
hdrs = ["elemental_ir_emitter.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -351,6 +351,7 @@ cc_library(
|
||||
":cudnn_convolution_runner",
|
||||
":gpu_executable",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
@ -382,7 +383,7 @@ cc_library(
|
||||
hdrs = ["cudnn_convolution_rewriter.h"],
|
||||
deps = [
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -517,6 +518,7 @@ cc_library(
|
||||
hdrs = ["pad_insertion.h"],
|
||||
deps = [
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
@ -533,7 +535,7 @@ cc_library(
|
||||
hdrs = ["gpu_transfer_manager.h"],
|
||||
deps = [
|
||||
":gpu_compiler",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -624,6 +626,7 @@ cc_library(
|
||||
hdrs = ["cudnn_batchnorm_rewriter.h"],
|
||||
deps = [
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
@ -716,7 +719,7 @@ cc_library(
|
||||
srcs = ["while_transformer.cc"],
|
||||
hdrs = ["while_transformer.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user