From 8779f768a3c0fa8d48c25f71d65513c208c77432 Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Tue, 3 Jul 2018 19:01:49 -0700 Subject: [PATCH] [TF:XLA] Split literal_util into {literal, literal_util}. Currently Literal classes sits in literal_util.{h,cc} instead of literal.{h,cc}. It also contains helper functions that are better fit to be their own separate class/namespace. This change starts this process by moving most static factory methods to LiteralUtil namespace. PiperOrigin-RevId: 203217065 --- tensorflow/compiler/tf2xla/BUILD | 8 +- tensorflow/compiler/tf2xla/kernels/BUILD | 6 +- .../compiler/tf2xla/kernels/bcast_ops.cc | 2 +- tensorflow/compiler/tf2xla/kernels/elu_op.cc | 2 +- .../compiler/tf2xla/kernels/index_ops_cpu.cc | 8 +- .../compiler/tf2xla/kernels/pooling_ops.cc | 2 +- .../compiler/tf2xla/kernels/reduction_ops.cc | 2 +- .../tf2xla/kernels/reduction_ops_common.cc | 2 +- tensorflow/compiler/tf2xla/kernels/relu_op.cc | 2 +- .../compiler/tf2xla/kernels/reshape_op.cc | 2 +- .../compiler/tf2xla/kernels/reverse_op.cc | 2 +- .../compiler/tf2xla/kernels/scan_ops.cc | 2 +- .../compiler/tf2xla/kernels/sequence_ops.cc | 2 +- .../compiler/tf2xla/kernels/split_op.cc | 2 +- .../compiler/tf2xla/kernels/stack_ops.cc | 2 +- .../tf2xla/kernels/tensor_array_ops.cc | 2 +- tensorflow/compiler/tf2xla/kernels/topk_op.cc | 2 +- .../compiler/tf2xla/kernels/training_ops.cc | 2 +- .../compiler/tf2xla/kernels/unpack_op.cc | 2 +- .../compiler/tf2xla/kernels/variable_ops.cc | 2 +- .../compiler/tf2xla/kernels/while_op.cc | 2 +- tensorflow/compiler/tf2xla/lib/BUILD | 11 +- tensorflow/compiler/tf2xla/lib/batch_dot.cc | 2 +- tensorflow/compiler/tf2xla/lib/cholesky.cc | 2 +- tensorflow/compiler/tf2xla/lib/scatter.cc | 4 +- .../compiler/tf2xla/lib/triangular_solve.cc | 2 +- .../tf2xla/lib/triangular_solve_test.cc | 2 +- tensorflow/compiler/tf2xla/lib/util.cc | 32 +- tensorflow/compiler/tf2xla/lib/util_test.cc | 2 +- tensorflow/compiler/tf2xla/lib/while_loop.cc | 9 +- tensorflow/compiler/tf2xla/literal_util.cc | 2 +- tensorflow/compiler/tf2xla/literal_util.h | 2 +- .../compiler/tf2xla/literal_util_test.cc | 5 +- tensorflow/compiler/tf2xla/tf2xla_test.cc | 5 +- .../compiler/tf2xla/xla_compiler_test.cc | 74 +- tensorflow/compiler/tf2xla/xla_context.cc | 2 +- tensorflow/compiler/tf2xla/xla_helpers.cc | 4 +- tensorflow/compiler/xla/BUILD | 44 +- tensorflow/compiler/xla/client/BUILD | 2 +- tensorflow/compiler/xla/client/client.cc | 2 +- tensorflow/compiler/xla/client/client.h | 2 +- tensorflow/compiler/xla/client/lib/BUILD | 3 +- .../compiler/xla/client/lib/constants.cc | 8 +- .../compiler/xla/client/lib/math_test.cc | 3 +- tensorflow/compiler/xla/client/lib/testing.cc | 4 +- .../compiler/xla/client/xla_client/BUILD | 3 +- .../xla/client/xla_client/xla_builder.cc | 2 +- .../xla/client/xla_client/xla_builder.h | 40 +- tensorflow/compiler/xla/literal.cc | 1967 +++++++++++++++ tensorflow/compiler/xla/literal.h | 1152 +++++++++ tensorflow/compiler/xla/literal_comparison.cc | 5 +- tensorflow/compiler/xla/literal_comparison.h | 2 +- .../{literal_util_test.cc => literal_test.cc} | 536 +++-- tensorflow/compiler/xla/literal_util.cc | 2115 +---------------- tensorflow/compiler/xla/literal_util.h | 1171 +-------- .../compiler/xla/packed_literal_reader.cc | 2 +- .../compiler/xla/packed_literal_reader.h | 2 +- tensorflow/compiler/xla/python/BUILD | 3 +- .../xla/python/local_computation_builder.i | 2 +- .../compiler/xla/python/numpy_bridge.cc | 5 +- tensorflow/compiler/xla/python/numpy_bridge.h | 2 +- tensorflow/compiler/xla/reference_util.cc | 5 +- .../compiler/xla/reference_util_test.cc | 46 +- .../compiler/xla/rpc/grpc_client_test.cc | 2 +- tensorflow/compiler/xla/service/BUILD | 97 +- .../xla/service/algebraic_simplifier.cc | 15 +- .../xla/service/algebraic_simplifier_test.cc | 112 +- .../xla/service/batchnorm_expander.cc | 17 +- .../xla/service/batchnorm_expander_test.cc | 2 +- .../xla/service/bfloat16_propagation.cc | 2 +- .../xla/service/bfloat16_propagation_test.cc | 8 +- .../xla/service/buffer_assignment_test.cc | 65 +- .../xla/service/buffer_liveness_test.cc | 24 +- .../compiler/xla/service/call_graph_test.cc | 10 +- .../compiler/xla/service/call_inliner_test.cc | 16 +- .../xla/service/computation_placer.cc | 2 +- .../xla/service/conditional_simplifier.cc | 2 +- .../service/conditional_simplifier_test.cc | 12 +- .../xla/service/copy_insertion_test.cc | 103 +- tensorflow/compiler/xla/service/cpu/BUILD | 11 +- .../service/cpu/conv_canonicalization_test.cc | 8 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 2 +- .../service/cpu/cpu_copy_insertion_test.cc | 8 +- .../cpu/cpu_instruction_fusion_test.cc | 4 +- .../service/cpu/cpu_layout_assignment_test.cc | 2 +- .../xla/service/cpu/cpu_transfer_manager.cc | 7 +- .../xla/service/cpu/sample_harness.cc | 9 +- .../compiler/xla/service/cpu/tests/BUILD | 6 +- .../cpu/tests/cpu_external_constants_test.cc | 2 +- .../xla/service/cpu/tests/cpu_fusion_test.cc | 24 +- .../xla/service/cpu/tests/cpu_infeed_test.cc | 77 +- .../xla/service/cpu/tests/cpu_noalias_test.cc | 4 +- .../compiler/xla/service/defuser_test.cc | 6 +- .../compiler/xla/service/dfs_hlo_visitor.h | 2 +- .../service/dfs_hlo_visitor_with_default.h | 2 +- .../xla/service/elemental_ir_emitter_test.cc | 4 +- .../xla/service/flatten_call_graph_test.cc | 14 +- .../compiler/xla/service/gather_expander.cc | 3 +- .../xla/service/generic_transfer_manager.cc | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 13 +- .../service/gpu/cudnn_batchnorm_rewriter.cc | 28 +- .../gpu/cudnn_convolution_algorithm_picker.cc | 8 +- .../service/gpu/cudnn_convolution_rewriter.cc | 2 +- .../xla/service/gpu/elemental_ir_emitter.cc | 2 +- .../service/gpu/gpu_layout_assignment_test.cc | 12 +- .../xla/service/gpu/gpu_transfer_manager.cc | 2 +- .../service/gpu/instruction_fusion_test.cc | 8 +- .../xla/service/gpu/ir_emitter_unnested.cc | 2 +- .../compiler/xla/service/gpu/pad_insertion.cc | 11 +- .../xla/service/gpu/while_transformer.cc | 2 +- .../xla/service/gpu/while_transformer_test.cc | 16 +- .../compiler/xla/service/graphviz_example.cc | 5 +- .../xla/service/heap_simulator_test.cc | 6 +- .../xla/service/hlo_alias_analysis_test.cc | 66 +- .../xla/service/hlo_computation_test.cc | 54 +- .../xla/service/hlo_constant_folding.cc | 2 +- .../xla/service/hlo_constant_folding_test.cc | 16 +- .../xla/service/hlo_cost_analysis_test.cc | 10 +- .../xla/service/hlo_creation_utils.cc | 9 +- .../xla/service/hlo_creation_utils_test.cc | 46 +- tensorflow/compiler/xla/service/hlo_cse.cc | 2 +- .../compiler/xla/service/hlo_cse_test.cc | 62 +- .../xla/service/hlo_dataflow_analysis_test.cc | 150 +- .../compiler/xla/service/hlo_dce_test.cc | 20 +- .../xla/service/hlo_element_type_converter.cc | 2 +- .../compiler/xla/service/hlo_evaluator.cc | 10 +- .../xla/service/hlo_evaluator_test.cc | 395 +-- .../xla/service/hlo_evaluator_typed_visitor.h | 22 +- .../compiler/xla/service/hlo_graph_dumper.cc | 2 +- .../xla/service/hlo_graph_dumper_test.cc | 3 +- .../compiler/xla/service/hlo_instruction.cc | 2 +- .../compiler/xla/service/hlo_instruction.h | 2 +- .../xla/service/hlo_instruction_test.cc | 52 +- .../compiler/xla/service/hlo_instructions.cc | 3 +- .../xla/service/hlo_liveness_analysis_test.cc | 2 +- .../compiler/xla/service/hlo_matchers_test.cc | 7 +- .../compiler/xla/service/hlo_module_test.cc | 8 +- .../compiler/xla/service/hlo_ordering_test.cc | 6 +- tensorflow/compiler/xla/service/hlo_parser.cc | 7 +- tensorflow/compiler/xla/service/hlo_query.cc | 2 +- .../xla/service/hlo_reachability_test.cc | 10 +- .../xla/service/hlo_rematerialization_test.cc | 8 +- .../xla/service/hlo_scheduling_test.cc | 40 +- .../compiler/xla/service/hlo_sharding.h | 2 +- .../compiler/xla/service/hlo_sharding_test.cc | 2 +- .../hlo_subcomputation_unification_test.cc | 6 +- .../xla/service/hlo_tfgraph_builder.cc | 2 +- .../xla/service/hlo_tfgraph_builder_test.cc | 2 +- .../implicit_broadcast_remover_test.cc | 2 +- .../compiler/xla/service/inliner_test.cc | 28 +- .../compiler/xla/service/interpreter/BUILD | 2 +- .../xla/service/interpreter/executable.cc | 2 +- .../xla/service/layout_assignment_test.cc | 22 +- tensorflow/compiler/xla/service/llvm_ir/BUILD | 2 +- .../compiler/xla/service/llvm_ir/llvm_util.cc | 2 +- .../compiler/xla/service/llvm_ir/llvm_util.h | 2 +- .../compiler/xla/service/reshape_mover.cc | 2 +- .../xla/service/reshape_mover_test.cc | 15 +- .../compiler/xla/service/transfer_manager.h | 2 +- .../xla/service/transpose_folding_test.cc | 8 +- .../service/tuple_points_to_analysis_test.cc | 81 +- .../xla/service/tuple_simplifier_test.cc | 2 +- .../while_loop_invariant_code_motion_test.cc | 4 +- .../xla/service/while_loop_simplifier_test.cc | 4 +- tensorflow/compiler/xla/service/while_util.cc | 9 +- .../xla/service/zero_sized_hlo_elimination.cc | 2 +- .../zero_sized_hlo_elimination_test.cc | 2 +- tensorflow/compiler/xla/tests/BUILD | 70 +- .../xla/tests/array_elementwise_ops_test.cc | 89 +- .../xla/tests/batch_normalization_test.cc | 94 +- .../compiler/xla/tests/bfloat16_test.cc | 18 +- .../xla/tests/broadcast_simple_test.cc | 96 +- .../compiler/xla/tests/broadcast_test.cc | 79 +- tensorflow/compiler/xla/tests/call_test.cc | 20 +- .../xla/tests/check_execution_arity_test.cc | 10 +- .../xla/tests/client_library_test_base.cc | 15 +- .../xla/tests/client_library_test_base.h | 44 +- tensorflow/compiler/xla/tests/client_test.cc | 10 +- .../xla/tests/compilation_cache_test.cc | 16 +- .../xla/tests/compute_constant_test.cc | 10 +- tensorflow/compiler/xla/tests/concat_test.cc | 16 +- .../compiler/xla/tests/conditional_test.cc | 32 +- .../compiler/xla/tests/constants_test.cc | 25 +- tensorflow/compiler/xla/tests/convert_test.cc | 16 +- .../convolution_dimension_numbers_test.cc | 3 +- .../compiler/xla/tests/convolution_test.cc | 68 +- .../xla/tests/convolution_variants_test.cc | 14 +- tensorflow/compiler/xla/tests/copy_test.cc | 27 +- .../xla/tests/cross_replica_sum_test.cc | 16 +- .../compiler/xla/tests/custom_call_test.cc | 6 +- .../xla/tests/deconstruct_tuple_test.cc | 4 +- .../compiler/xla/tests/dot_operation_test.cc | 68 +- .../compiler/xla/tests/dynamic_ops_test.cc | 42 +- .../xla/tests/execution_profile_test.cc | 2 +- .../exhaustive_f32_elementwise_op_test.cc | 2 +- tensorflow/compiler/xla/tests/fusion_test.cc | 155 +- .../xla/tests/gather_operation_test.cc | 123 +- tensorflow/compiler/xla/tests/half_test.cc | 2 +- .../compiler/xla/tests/literal_test_util.h | 31 +- .../xla/tests/literal_test_util_test.cc | 46 +- .../compiler/xla/tests/llvm_compiler_test.cc | 5 +- .../xla/tests/local_client_allocation_test.cc | 4 +- .../xla/tests/local_client_execute_test.cc | 123 +- tensorflow/compiler/xla/tests/map_test.cc | 52 +- .../xla/tests/matrix_ops_simple_test.cc | 24 +- .../xla/tests/multioutput_fusion_test.cc | 120 +- tensorflow/compiler/xla/tests/pad_test.cc | 51 +- tensorflow/compiler/xla/tests/params_test.cc | 60 +- tensorflow/compiler/xla/tests/prng_test.cc | 4 +- .../compiler/xla/tests/reduce_hlo_test.cc | 30 +- .../xla/tests/reduce_precision_test.cc | 15 +- tensorflow/compiler/xla/tests/reduce_test.cc | 42 +- .../compiler/xla/tests/reduce_window_test.cc | 136 +- tensorflow/compiler/xla/tests/replay_test.cc | 6 +- .../compiler/xla/tests/reshape_motion_test.cc | 2 +- tensorflow/compiler/xla/tests/reshape_test.cc | 208 +- tensorflow/compiler/xla/tests/reverse_test.cc | 2 +- .../tests/round_trip_packed_literal_test.cc | 2 +- .../xla/tests/round_trip_transfer_test.cc | 50 +- .../xla/tests/scalar_computations_test.cc | 25 +- .../xla/tests/select_and_scatter_test.cc | 2 +- tensorflow/compiler/xla/tests/slice_test.cc | 8 +- tensorflow/compiler/xla/tests/test_utils.cc | 9 +- tensorflow/compiler/xla/tests/test_utils.h | 2 +- .../compiler/xla/tests/token_hlo_test.cc | 12 +- .../xla/tests/transfer_manager_test.cc | 80 +- tensorflow/compiler/xla/tests/tuple_test.cc | 106 +- .../compiler/xla/tests/unary_op_test.cc | 8 +- tensorflow/compiler/xla/tests/while_test.cc | 69 +- .../compiler/xla/text_literal_reader.cc | 2 +- tensorflow/compiler/xla/text_literal_reader.h | 2 +- .../compiler/xla/text_literal_reader_test.cc | 2 +- .../compiler/xla/text_literal_writer.cc | 2 +- tensorflow/compiler/xla/text_literal_writer.h | 2 +- .../compiler/xla/text_literal_writer_test.cc | 6 +- tensorflow/compiler/xla/tools/BUILD | 6 +- .../compiler/xla/tools/replay_computation.cc | 2 +- tensorflow/compiler/xla/tools/show_literal.cc | 2 +- .../compiler/xla/tools/show_text_literal.cc | 2 +- 239 files changed, 6151 insertions(+), 5779 deletions(-) create mode 100644 tensorflow/compiler/xla/literal.cc create mode 100644 tensorflow/compiler/xla/literal.h rename tensorflow/compiler/xla/{literal_util_test.cc => literal_test.cc} (76%) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 40e32f2e757..fd31c265443 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -162,7 +162,7 @@ cc_library( ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -202,7 +202,7 @@ cc_library( ], visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", @@ -285,6 +285,7 @@ tf_cc_test( deps = [ ":tf2xla", ":tf2xla_proto", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", @@ -327,7 +328,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", @@ -364,6 +365,7 @@ tf_cc_test( ], deps = [ ":common", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:framework", "//tensorflow/core:test", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index a8eb7d942df..d1e89828771 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -114,6 +114,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -159,7 +160,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -175,7 +176,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -210,6 +211,7 @@ tf_kernel_library( ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ee2c920453c..ba3b1c9dab7 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2c76bcee259..81f42e504e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index a020ebc729e..22a45b2a11e 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1(input_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1(output_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::Literal::CreateR0(dim))); + xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a81f5fddf69..12d9cb9bac6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 46fae59ad4f..be7f2bce8cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 909783ecb3c..ed1d1c66109 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a4ba6c748a7..f4b804e5467 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index e0ca8dd8e27..354fec9be75 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 037c4222585..ec15b4cc7a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 76924c6a01a..27ab3e1bf5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index bc3d0bf5dfe..25a5bcbe1dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index ca74cf24507..242638f9811 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 591e61b4c82..df919005701 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 2f650ce3052..26326f18b84 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 9962f1207d6..1ddcb08c8e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index bef6161e854..b62a6e778df 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 0e5d58ecbae..f951127bb95 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index febac828735..bb27b5d56f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 340165bac6a..9413a30a6c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index dfa3c0595ac..20fa03746c9 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -40,7 +40,7 @@ cc_library( ":triangular_solve", ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -73,7 +73,7 @@ cc_library( deps = [ ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -92,7 +92,7 @@ cc_library( deps = [ ":batch_dot", ":util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -111,7 +111,7 @@ xla_test( deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -133,6 +133,7 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -151,7 +152,7 @@ xla_test( ":batch_dot", ":util", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f9f3a8c8cfc..3c4eec081ba 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -84,7 +84,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dimensions.push_back(y_shape.dimensions(y_outer_dim)); return xla::Broadcast( xla::ConstantLiteral(builder, - xla::Literal::Zero(x_shape.element_type())), + xla::LiteralUtil::Zero(x_shape.element_type())), dimensions); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index cc840de393e..35b137aa2cc 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 85e3d3ab85a..6a5be1c2be5 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -114,7 +114,7 @@ xla::StatusOr XlaScatter( auto buffer = loop_vars[2]; auto zero_index = xla::ConstantLiteral( - body_builder, xla::Literal::Zero(indices_shape.element_type())); + body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 588afaac651..ce0f28db8f6 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index d5ffc1498e4..f1bff6037bf 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index fdc8bfca493..a6f5d346cb5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,6 +29,13 @@ limitations under the License. namespace tensorflow { +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(shape.element_type())), + xla::AsInt64Slice(shape.dimensions())); +} + xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { @@ -56,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::C64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -89,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + *xla::LiteralUtil::CreateR0(static_cast(value))); break; case xla::F16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + literal = std::move(*xla::LiteralUtil::CreateR0( + static_cast(value))); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 7d0f2222a9a..442fe92c34c 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 7cc88f34d29..574e70ddeea 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -100,8 +100,9 @@ xla::StatusOr> XlaForEachIndex( std::vector updated_values; updated_values.reserve(values.size()); updated_values.push_back(xla::Add( - iteration, xla::ConstantLiteral( - body_builder, xla::Literal::One(num_iterations_type)))); + iteration, + xla::ConstantLiteral(body_builder, + xla::LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector body_outputs, @@ -113,8 +114,8 @@ xla::StatusOr> XlaForEachIndex( std::vector values; values.reserve(initial_values.size() + 1); - values.push_back( - xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); + values.push_back(xla::ConstantLiteral( + builder, xla::LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index b43405a1a40..2fb66913ada 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index ab7e861f333..0610a57029e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index f3d6787daaa..a3404c2b3df 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 84c133ffabe..f0b30dcf4e9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -73,8 +74,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. - auto x_literal = xla::Literal::CreateR0(10); - auto y_literal = xla::Literal::CreateR0(32); + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); auto x_global_or = client->TransferToServer(*x_literal); auto y_global_or = client->TransferToServer(*y_literal); TF_EXPECT_OK(x_global_or.status()); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 07af8ef54b7..6f76816a861 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -222,9 +222,9 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -306,7 +306,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -317,9 +317,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE( xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -341,7 +341,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -351,11 +351,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = xla::Literal::CreateR0(7); + std::unique_ptr expected0 = + xla::LiteralUtil::CreateR0(7); std::unique_ptr expected1 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } @@ -569,11 +570,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { // Tests that the generated computation works. std::unique_ptr input_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr input_grad2 = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr input = - xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); std::unique_ptr param0_data = client_->TransferToServer(*input).ConsumeValueOrDie(); @@ -583,17 +584,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr output_read = xla::Literal::CreateR0(42); + std::unique_ptr output_read = + xla::LiteralUtil::CreateR0(42); std::unique_ptr output_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr output_grad1 = - xla::Literal::CreateR1({0, 1}); + xla::LiteralUtil::CreateR1({0, 1}); std::unique_ptr output_grad2 = - xla::Literal::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::Literal::MakeTuple( + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -796,9 +798,9 @@ TEST_F(XlaCompilerTest, Variables) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -812,11 +814,11 @@ TEST_F(XlaCompilerTest, Variables) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({5, 144}); + xla::LiteralUtil::CreateR1({5, 144}); std::unique_ptr expected1 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -884,9 +886,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR2({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -900,11 +902,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR2({{27, 67}, {35, 402}}); + xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -953,9 +955,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1({4, 55, 1, -3}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -969,11 +971,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({27, 67, 35, 402}); + xla::LiteralUtil::CreateR1({27, 67, 35, 402}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index fd39a58ce64..0dea3664769 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index edbc5e95a8c..4d1b3b1a135 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -94,13 +94,13 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::One(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::One(type)); } xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 03e542855ba..2f43b18d390 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -281,9 +281,9 @@ tf_cc_test( ) cc_library( - name = "literal_util", - srcs = ["literal_util.cc"], - hdrs = ["literal_util.h"], + name = "literal", + srcs = ["literal.cc"], + hdrs = ["literal.h"], visibility = ["//visibility:public"], deps = [ ":array2d", @@ -300,11 +300,12 @@ cc_library( ) tf_cc_test( - name = "literal_util_test", - srcs = ["literal_util_test.cc"], + name = "literal_test", + srcs = ["literal_test.cc"], deps = [ ":array3d", ":array4d", + ":literal", ":literal_util", ":shape_util", ":test", @@ -316,6 +317,26 @@ tf_cc_test( ], ) +cc_library( + name = "literal_util", + srcs = ["literal_util.cc"], + hdrs = ["literal_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":array2d", + ":array3d", + ":array4d", + ":literal", + ":shape_util", + ":sparse_index_array", + ":status_macros", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "error_spec", hdrs = ["error_spec.h"], @@ -327,6 +348,7 @@ cc_library( hdrs = ["literal_comparison.h"], deps = [ ":error_spec", + ":literal", ":literal_util", ":util", "//tensorflow/core:lib", @@ -458,7 +480,7 @@ cc_library( hdrs = ["packed_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -489,7 +511,7 @@ cc_library( hdrs = ["text_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -505,7 +527,7 @@ tf_cc_test( name = "text_literal_reader_test", srcs = ["text_literal_reader_test.cc"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":test", ":text_literal_reader", @@ -522,7 +544,7 @@ cc_library( hdrs = ["text_literal_writer.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":types", @@ -535,6 +557,7 @@ tf_cc_test( name = "text_literal_writer_test", srcs = ["text_literal_writer_test.cc"], deps = [ + ":literal", ":literal_util", ":test", ":test_helpers", @@ -607,6 +630,7 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":literal_util", ":util", ":window_util", ":xla_data_proto", @@ -627,7 +651,7 @@ tf_cc_test( ":array2d", ":array3d", ":array4d", - ":literal_util", + ":literal", ":reference_util", ":test", ":util", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 8f08d3b2e04..25666cad40e 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -65,7 +65,7 @@ cc_library( deps = [ ":global_data", "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 3d596a6e654..3a157c69cd7 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 68f0d0ac78c..69d4d300ca9 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index a6b9b472532..6933e9a838c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -82,6 +82,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ ":math", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -123,7 +124,7 @@ cc_library( hdrs = ["testing.h"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 1686389a234..031d62e4ffe 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -21,7 +21,7 @@ limitations under the License. namespace xla { XlaOp Zero(XlaBuilder* builder, PrimitiveType type) { - return ConstantLiteral(builder, Literal::Zero(type)); + return ConstantLiteral(builder, LiteralUtil::Zero(type)); } XlaOp Zeros(XlaBuilder* builder, const Shape& shape) { @@ -38,7 +38,7 @@ XlaOp ZerosLike(XlaOp prototype) { } XlaOp One(XlaBuilder* builder, PrimitiveType type) { - return ConstantLiteral(builder, Literal::One(type)); + return ConstantLiteral(builder, LiteralUtil::One(type)); } XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { @@ -61,7 +61,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { } XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { - return ConstantLiteral(builder, Literal::MinValue(type)); + return ConstantLiteral(builder, LiteralUtil::MinValue(type)); } XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { @@ -81,7 +81,7 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { } XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { - return ConstantLiteral(builder, Literal::MaxValue(type)); + return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); } XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 1df4e6ea42a..068cd2e5861 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -31,7 +32,7 @@ class MathTest : public ClientLibraryTestBase { XLA_TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 731ad13b8d0..534c5098683 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -49,7 +49,7 @@ int64 DataSizeOfShape(const Shape& shape) { XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return Broadcast( - ConstantLiteral(builder, Literal::One(shape.element_type())), + ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } std::vector parts; diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index ee00a9eada8..763653c685c 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -64,7 +65,7 @@ tf_cc_test( srcs = ["xla_builder_test.cc"], deps = [ ":xla_builder", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 12efcb4b4f7..d4759a0fff2 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -736,7 +736,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 274aba8a310..fbcdb4c802f 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -1943,12 +1944,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); + return ConstantLiteral(*LiteralUtil::CreateR0(value)); } template XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); + return ConstantLiteral(*LiteralUtil::CreateR1(values)); } template @@ -1960,44 +1961,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); + return ConstantLiteral(*LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); + return ConstantLiteral(*LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); + return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); + return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2020,13 +2021,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *Literal::CreateR0(value)); + return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); } template XlaOp ConstantR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(builder, *Literal::CreateR1(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } template @@ -2039,13 +2040,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *Literal::CreateR1(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *Literal::CreateR2(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); } template @@ -2053,12 +2054,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { - return ConstantLiteral(builder, *Literal::CreateFromArray(values)); + return ConstantLiteral(builder, + *LiteralUtil::CreateFromArray(values)); } template @@ -2066,14 +2069,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *Literal::CreateR2FromArray2D(values)); + *LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2082,7 +2086,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *Literal::CreateR3FromArray3DWithLayout(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc new file mode 100644 index 00000000000..5db124b5a22 --- /dev/null +++ b/tensorflow/compiler/xla/literal.cc @@ -0,0 +1,1967 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + +namespace { + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian. +// +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() / 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size / 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} + +} // namespace + +LiteralBase::~LiteralBase() {} + +std::ostream& operator<<(std::ostream& out, const Literal& literal) { + out << literal.ToString(); + return out; +} + +Literal::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the largest minor dimension as the one upon + // which to run the tight stride loop. + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + } else { + piece->set_buffer(new char[piece->size_bytes()]); + } + } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); + } +} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); + + SetPiece(*shape_, root_piece_, allocate_arrays); +} + +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; + } +} + +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); +} + +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + +Literal& Literal::operator=(Literal&& other) { + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + + return *this; +} + +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { + auto literal = MakeUnique(shape); + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); + return literal; +} + +const SparseIndexArray* LiteralBase::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + +template +Status Literal::CopySliceFromInternal( + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; + + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. + TF_RET_CHECK(copy_size.empty()); + StridedCopy(data(), linear_index(shape(), dest_base), 0, + src_literal.data(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { + // Perform copy if neither src nor dest has dimensions with zero element, + // otherwise it's a no-op. + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumaration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + Literal::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); + + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + // Map from multi-dimensional index, to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus()); + // Map from multi-dimensional index, to destination index. + std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); + + // `this->` is needed to workaround MSVC bug: #16882 + StridedCopy(this->data(), dest_index, stride_config.dest_stride, + src_literal.data(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + +/* static */ StatusOr> Literal::CreateFromProto( + const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique(proto.shape()); + + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); + + return std::move(literal); +} + +std::vector Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. +template +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, + tensorflow::gtl::ArraySlice src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::IsZeroElementArray(dest_shape)) { + return; + } + std::vector index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Copying a Literal object with element type %s is not implemented.", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status Literal::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } + + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); + + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); + + return Status::OK(); +} + +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); + + switch (shape().element_type()) { + case U8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case BF16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case C64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case PRED: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + default: + break; + } + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); +} + +void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + +std::unique_ptr LiteralBase::Relayout( + const Layout& new_layout, const ShapeIndex& shape_index) const { + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; +} + +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Reshape does not support tuples."); + } + std::unique_ptr output; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + } else { + output = CloneToUnique(); + } + // Because the layout is monotonic, we can simply reuse the same sequence of + // values without changing their order. + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); + + int64 elements_before = ShapeUtil::ElementsIn(shape()); + int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + if (elements_before != elements_after) { + return InvalidArgument( + "Shapes before and after Literal::Reshape have different numbers " + "of elements: %s vs %s.", + ShapeUtil::HumanString(shape()).c_str(), + ShapeUtil::HumanString(output->shape()).c_str()); + } + return std::move(output); +} + +std::unique_ptr LiteralBase::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the + // most minor. + // + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); + Layout* layout = permuted_shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : LayoutUtil::MinorToMajor(shape())) { + layout->add_minor_to_major(inverse_permutation[index]); + } + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + return new_literal; +} + +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + + DimensionVector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + CHECK_GE(start_indices[dnum], 0); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; + int64 dimension = limit_indices[dnum] - start_indices[dnum]; + CHECK_GE(dimension, 0) << "dnum = " << dnum; + result_dimensions.push_back(dimension); + } + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); + switch (result_shape.element_type()) { + case F32: + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); + case C64: + return SliceInternal(result_shape, start_indices); + case S32: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(result_shape.element_type()); + } +} + +Literal LiteralBase::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::CloneToUnique() const { + auto result = MakeUnique(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return Get(multi_index, shape_index) ? "true" : "false"; + case S8: + return StrCat(Get(multi_index, shape_index)); + case S16: + return StrCat(Get(multi_index, shape_index)); + case S32: + return StrCat(Get(multi_index, shape_index)); + case S64: + return StrCat(Get(multi_index, shape_index)); + case U8: + return StrCat(Get(multi_index, shape_index)); + case U16: + return StrCat(Get(multi_index, shape_index)); + case U32: + return StrCat(Get(multi_index, shape_index)); + case U64: + return StrCat(Get(multi_index, shape_index)); + case F16: + return StrCat(static_cast(Get(multi_index, shape_index))); + case F32: + return StrCat(Get(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast(Get(multi_index, shape_index))); + case F64: + return StrCat(Get(multi_index, shape_index)); + case C64: { + complex64 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case F16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case BF16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); + } +} + +StatusOr LiteralBase::GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + return Get(multi_index); + case U8: + return Get(multi_index); + case S32: + return Get(multi_index); + case S64: + return Get(multi_index); + case U32: + return Get(multi_index); + case U64: + return Get(multi_index); + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } +} + +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsArray(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + +Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); +} + +void Literal::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void LiteralBase::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal(); + break; + case S8: + SortSparseElementsInternal(); + break; + case U8: + SortSparseElementsInternal(); + break; + case S16: + SortSparseElementsInternal(); + break; + case U16: + SortSparseElementsInternal(); + break; + case S32: + SortSparseElementsInternal(); + break; + case U32: + SortSparseElementsInternal(); + break; + case S64: + SortSparseElementsInternal(); + break; + case U64: + SortSparseElementsInternal(); + break; + case F32: + SortSparseElementsInternal(); + break; + case F64: + SortSparseElementsInternal(); + break; + case C64: + SortSparseElementsInternal(); + break; + case F16: + SortSparseElementsInternal(); + break; + case BF16: + SortSparseElementsInternal(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template +void LiteralBase::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + + auto element_to_string = + [&](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = subshape.element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return literal.Get(indices, shape_index) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + literal.GetAsString(indices, shape_index); + }; + + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); + } + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? ",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); + } + pieces->push_back(" }"); + } + pieces->push_back(" }"); + } + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + } + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); + } +} + +} // namespace + +int64 LiteralBase::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + +string LiteralBase::ToString(bool print_layout) const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, print_layout, &pieces); + return tensorflow::str_util::Join(pieces, ""); +} + +void LiteralBase::EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +namespace { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const LiteralBase& src_literal, const ConverterType& converter) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType())); + auto src_data = src_literal.data(); + auto dest_data = result_literal->template data(); + int64 num_elements = src_literal.element_count(); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = converter(src_data[i]); + } + return result_literal; +} + +template +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. +template +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + +template +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.data(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->data(); + int64 num_elements = src_literal.element_count(); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + +template +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, + bool bitcast) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } +} + +template +StatusOr> ConvertIfDestTypeMatches( + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal, \ + bitcast); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F16) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) +#undef CONVERT_IF_TYPES_MATCH + case C64: + if (!bitcast) { + return ConvertToC64(src_literal); + } + break; + // Other types are not yet supported. + default: + break; + } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); +} + +StatusOr> ConvertSwitch( + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F16) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); + } +} + +} // namespace + +StatusOr> LiteralBase::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr> LiteralBase::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr> LiteralBase::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralSlice(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = Literal::MoveIntoTuple(&elements); + return std::move(converted); +} + +/* static */ Literal Literal::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements) { + std::vector element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +template +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get(*multi_index) == other.Get(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal(other, multi_index)) { + return false; + } + multi_index->pop_back(); + } + return true; +} + +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); + case S32: + return EqualElementsInternal(other, &multi_index); + case S64: + return EqualElementsInternal(other, &multi_index); + case U32: + return EqualElementsInternal(other, &multi_index); + case U64: + return EqualElementsInternal(other, &multi_index); + case F32: + return EqualElementsInternal(other, &multi_index); + case F64: + return EqualElementsInternal(other, &multi_index); + case F16: + return EqualElementsInternal(other, &multi_index); + case BF16: + return EqualElementsInternal(other, &multi_index); + case C64: + return EqualElementsInternal(other, &multi_index); + default: + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +bool LiteralBase::operator==(const LiteralBase& other) const { + if (!ShapeUtil::Compatible(shape(), other.shape())) { + return false; + } + + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); +} + +namespace { + +template +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { + return false; + } + } + return true; +} + +} // namespace + +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case S8: + return AllElementsEqualValue(piece.data(), value); + case S32: + return AllElementsEqualValue(piece.data(), value); + case S64: + return AllElementsEqualValue(piece.data(), value); + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue(piece.data(), false); + } + if (value == 1) { + return AllElementsEqualValue(piece.data(), true); + } + return false; + default: + return false; + } + return false; + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(root_piece().data(), + value); + default: + return false; + } +} + +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { + CHECK(ShapeUtil::IsArray(shape())); + switch (shape().element_type()) { + case U8: + return Get(indices) == 0; + case U32: + return Get(indices) == 0; + case U64: + return Get(indices) == 0; + case S8: + return Get(indices) == 0; + case S32: + return Get(indices) == 0; + case S64: + return Get(indices) == 0; + case F32: + return Get(indices) == 0.0f; + case F64: + return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); + case F16: + return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); + case PRED: + return Get(indices) == false; + default: + LOG(FATAL) << "Input literal must be an array."; + } +} + +namespace { + +template +void CopyToRepeatedField(RepeatedFieldT* dest, + const tensorflow::gtl::ArraySlice src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +} // namespace + +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { + case PRED: + CopyToRepeatedField(proto->mutable_preds(), data()); + break; + case U8: + proto->set_u8s(static_cast(data().data()), + element_count()); + break; + case U32: + CopyToRepeatedField(proto->mutable_u32s(), data()); + break; + case U64: + CopyToRepeatedField(proto->mutable_u64s(), data()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data()); + break; + case S64: + CopyToRepeatedField(proto->mutable_s64s(), data()); + break; + case F16: + *proto->mutable_f16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_f16s()); + } + break; + case BF16: + *proto->mutable_bf16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_bf16s()); + } + break; + case F32: + CopyToRepeatedField(proto->mutable_f32s(), data()); + break; + case F64: + CopyToRepeatedField(proto->mutable_f64s(), data()); + break; + case C64: + for (complex64 value : data()) { + proto->add_c64s(value.real()); + proto->add_c64s(value.imag()); + } + break; + case TUPLE: + case TOKEN: + // Nothing to do but assign the shape which is done above. + return; + default: + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + } +} + +const void* LiteralBase::Piece::untyped_data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +void* LiteralBase::Piece::untyped_data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +namespace { + +template +Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, + const RepeatedFieldT& src) { + if (dest.size() != src.size()) { + return InvalidArgument( + "Expected %lu elements in LiteralProto repeated field, has %d", + dest.size(), src.size()); + } + std::copy(src.begin(), src.end(), dest.begin()); + return Status::OK(); +} + +} // namespace + +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { + // These conditions should have been checked in Literal::CreateFromProto. + TF_RET_CHECK(proto.has_shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); + TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + + switch (subshape().element_type()) { + case PRED: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); + break; + case U8: { + auto u8_data = data(); + TF_RET_CHECK(proto.u8s().size() == u8_data.size()); + std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); + } break; + case S32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); + break; + case S64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); + break; + case U32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); + break; + case U64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); + break; + case F16: { + const string& s(proto.f16s()); + TF_RET_CHECK(data().size() * sizeof(half) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + + case BF16: { + const string& s(proto.bf16s()); + TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case F32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); + break; + case F64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); + break; + case C64: { + auto complex_data = data(); + TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; + } + } break; + case TUPLE: + LOG(FATAL) << "Should not be called on tuple shapes: " + << ShapeUtil::HumanString(subshape()); + break; + default: + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + } + return Status::OK(); +} + +LiteralProto LiteralBase::ToProto() const { + LiteralProto proto; + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + + return proto; +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { + return piece(shape_index).untyped_data(); +} + +void* Literal::untyped_data(const ShapeIndex& shape_index) { + return piece(shape_index).untyped_data(); +} + +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { + return piece(shape_index).size_bytes(); +} + +string LiteralBase::GetR1U8AsString() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().element_type(), U8); + return string(tensorflow::bit_cast(data().data()), + ShapeUtil::ElementsIn(shape())); +} + +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); + } + + piece->emplace_back(std::move(child_piece)); + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(shape_.get()); +} + +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h new file mode 100644 index 00000000000..dd67dfa8d4a --- /dev/null +++ b/tensorflow/compiler/xla/literal.h @@ -0,0 +1,1152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + +// Class representing literal values in XLA. +// +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); + + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + + // Fills this literal with the given value. + template + void PopulateWithValue(NativeT value); + + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); + + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); + + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); + + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); + } + + Piece& root_piece() const override { return *root_piece_; }; + + // Internal template helper for the Literal::CopySliceFrom(), matching its + // arguments one by one. + template + Status CopySliceFromInternal(const LiteralBase& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. + struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; + + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; + + Piece* root_piece_ = nullptr; + + // Implementation details shared between Populate() and PopulateParallel() + template + Status PopulateInternal(const FnType& generator, bool parallel); + + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + + friend class LiteralBase; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { + public: + LiteralSlice() : LiteralBase() {} + + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); + + private: + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; + +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. + + private: + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); + + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; +}; + +template +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +NativeT LiteralBase::Piece::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + return data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)]; +} + +template +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)] = value; +} + +template +tensorflow::gtl::ArraySlice LiteralBase::data( + const ShapeIndex& shape_index) const { + return piece(shape_index).data(); +} + +template +tensorflow::gtl::MutableArraySlice Literal::data( + const ShapeIndex& shape_index) { + return piece(shape_index).data(); +} + +template +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).Get(multi_index); +} + +template +inline NativeT LiteralBase::Get( + tensorflow::gtl::ArraySlice multi_index) const { + return root_piece().Get(multi_index); +} + +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { + return piece(shape_index).Set(multi_index, value); +} + +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + return root_piece().Set(multi_index, value); +} + +template +NativeT LiteralBase::GetFirstElement() const { + return data().at(0); +} + +template +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data(shape_index)[sparse_element_number]; +} + +template +void Literal::AppendSparseElement( + tensorflow::gtl::ArraySlice multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data().size()); + p.data()[last_element] = value; +} + +template +void LiteralBase::EachCell( + std::function indices, + NativeT value)> + per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices(ShapeUtil::Rank(shape()), 0); + do { + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +template +inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (int64 i = 0; i < values.size(); ++i) { + Set({i}, values[i]); + } +} + +template +void Literal::PopulateR2( + std::initializer_list> values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + + const int64 dim0_size = values.size(); + const int64 dim1_size = values.begin()->size(); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto value : inner_list) { + Set({dim0, dim1}, value); + ++dim1; + } + CHECK_EQ(dim1_size, dim1); + ++dim0; + } +} + +template +void Literal::PopulateFromArray(const Array& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + for (int dim = 0; dim < values.num_dimensions(); ++dim) { + CHECK_EQ(values.dim(dim), shape().dimensions(dim)); + } + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} + +template +void Literal::PopulateR2FromArray2D(const Array2D& values) { + PopulateFromArray(values); +} + +template +void Literal::PopulateR3FromArray3D(const Array3D& values) { + PopulateFromArray(values); +} + +template +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateFromArray(values); +} + +template +void Literal::PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); +} + +template +Status Literal::PopulateInternal(const FnType& generator, bool parallel) { + const Shape& this_shape = shape(); + const int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); + TF_RET_CHECK(this_shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice literal_data = data(); + if (rank > 0) { + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); + int64 minor_dimension_size = + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); + + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + DimensionVector minor_scan_indexes(rank, 0); + const int64 index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[stride_config.minor_dimension] = i; + literal_data.at(index + i) = generator(minor_scan_indexes); + } + }; + if (parallel) { + ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, + stride_config.dimensions, + stride_config.step, init_function); + } else { + ShapeUtil::ForEachIndex( + this_shape, stride_config.base, stride_config.dimensions, + stride_config.step, + [&init_function](tensorflow::gtl::ArraySlice indexes) { + init_function(indexes); + return true; + }); + } + } else { + // For scalars. + literal_data.at(0) = generator({}); + } + return Status::OK(); +} +template +Status Literal::Populate(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/false); +} + +template +Status Literal::PopulateParallel(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/true); +} + +template +void Literal::PopulateWithValue(NativeT value) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (NativeT& element : data()) { + element = value; + } +} + +template +std::unique_ptr LiteralBase::Replicate(int64 times) const { + DimensionVector bounds = {times}; + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { + bounds.push_back(bound); + } + auto literal = + MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal->shape()); + if (elements == 0) { + return literal; + } + + DimensionVector output_indices(bounds.size(), 0); + tensorflow::gtl::ArraySlice input_indices = output_indices; + input_indices.remove_prefix(1); + + bool done = false; + while (!done) { + const auto element = Get(input_indices); + literal->Set(output_indices, element); + + done = true; + for (int n = 0; n < output_indices.size(); ++n) { + ++output_indices[n]; + if (output_indices[n] < bounds[n]) { + done = false; + break; + } + output_indices[n] = 0; + } + } + return literal; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_ diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 2125ab7c61a..94993cc8744 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -217,7 +218,7 @@ class NearComparator { return Printf( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - Literal::MultiIndexAsString( + LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index)) .c_str(), @@ -722,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { return AppendStatus(result, tensorflow::strings::Printf( "\nat index: %s\nexpected: %s\nactual: %s", - Literal::MultiIndexAsString(multi_index).c_str(), + LiteralUtil::MultiIndexAsString(multi_index).c_str(), ToStringTruncated(expected).c_str(), ToStringTruncated(actual).c_str())); } diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 00a13e36193..9e5bf7c1d06 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -20,7 +20,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ #include "tensorflow/compiler/xla/error_spec.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/lib/core/status.h" namespace xla { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_test.cc similarity index 76% rename from tensorflow/compiler/xla/literal_util_test.cc rename to tensorflow/compiler/xla/literal_test.cc index 493d807591d..e8f919950f0 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" @@ -76,11 +77,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -94,47 +95,47 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = Literal::CreateR0(true); + auto true_lit = LiteralUtil::CreateR0(true); ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = Literal::CreateR0(false); + auto false_lit = LiteralUtil::CreateR0(false); ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = Literal::CreateR0(42); + auto u32_lit = LiteralUtil::CreateR0(42); ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = Literal::CreateR0(-999); + auto s32_lit = LiteralUtil::CreateR0(-999); ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = Literal::CreateR0(3.14f); + auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); - auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); - auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", bf16_lit->ToString()); // 3.14 will be truncated to 3.125 in bfloat16 format. auto bf16_lit_truncated = - Literal::CreateR0(static_cast(3.14f)); + LiteralUtil::CreateR0(static_cast(3.14f)); ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); auto bf16_lit_truncated2 = - Literal::CreateR0(static_cast(9.001f)); + LiteralUtil::CreateR0(static_cast(9.001f)); ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = Literal::CreateR1({true, false, true}); + auto pred_vec = LiteralUtil::CreateR1({true, false, true}); ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, @@ -144,7 +145,8 @@ TEST_F(LiteralUtilTest, R2ToString) { } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = + LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -157,9 +159,9 @@ TEST_F(LiteralUtilTest, R3ToString) { } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -182,7 +184,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = Literal::CreateR3FromArray3D(array_3d); + auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal->ToString(); const string expected = R"(f32[2,3,2] { @@ -205,7 +207,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { {3, 5, 6}, }; std::vector values = {7, 8, 9, 10}; - auto literal = Literal::CreateSparse( + auto literal = LiteralUtil::CreateSparse( dimensions, SparseIndexArray(indices.n1() + 3, indices), values); Array2D expected_indices = { @@ -224,7 +226,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = Literal::CreateR4Projected({ + auto literal = LiteralUtil::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, @@ -284,7 +286,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = Literal::CreateR2({ + auto literal = LiteralUtil::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); @@ -303,26 +305,27 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { TEST_F(LiteralUtilTest, ScalarEquality) { // Test equality with scalars. - auto f32_42 = Literal::CreateR0(42.0); - auto f32_42_clone = Literal::CreateR0(42.0); + auto f32_42 = LiteralUtil::CreateR0(42.0); + auto f32_42_clone = LiteralUtil::CreateR0(42.0); EXPECT_EQ(*f32_42, *f32_42); EXPECT_EQ(*f32_42, *f32_42_clone); - auto f32_123 = Literal::CreateR0(123.0); + auto f32_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*f32_42, *f32_123); - auto f64_42 = Literal::CreateR0(42.0); + auto f64_42 = LiteralUtil::CreateR0(42.0); EXPECT_NE(*f32_42, *f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { // Test equality with nonscalars. - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = Literal::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = + LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(*matrix, *matrix); @@ -335,19 +338,19 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { } TEST_F(LiteralUtilTest, TokenEquality) { - auto token0 = Literal::CreateToken(); - auto token1 = Literal::CreateToken(); - auto scalar = Literal::CreateR0(1.0); + auto token0 = LiteralUtil::CreateToken(); + auto token1 = LiteralUtil::CreateToken(); + auto scalar = LiteralUtil::CreateR0(1.0); EXPECT_EQ(*token0, *token1); EXPECT_NE(*token0, *scalar); - EXPECT_EQ(*Literal::MakeTuple({token0.get()}), - *Literal::MakeTuple({token0.get()})); - EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}), - *Literal::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}), - *Literal::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), + *LiteralUtil::MakeTuple({token0.get()})); + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { @@ -371,43 +374,46 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) { TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto scalar_clone = Literal::CreateR0(1.0); - auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + auto scalar_clone = LiteralUtil::CreateR0(1.0); + auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); EXPECT_EQ(*tuple1, *tuple2); // Tuple with elements reversed. - auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); EXPECT_NE(*tuple1, *reversed_tuple); // Tuple with different value. - auto scalar_42 = Literal::CreateR0(42.0); - auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto different_tuple = + LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); EXPECT_NE(*tuple1, *different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { // Test equality with tuples. - auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); EXPECT_EQ(*vector, *vector_clone); - auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); EXPECT_NE(*vector, *vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = Literal::CreateR0(0.0); - auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); + auto element1 = LiteralUtil::CreateR0(0.0); + auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. EXPECT_FALSE(tuple->IsAll(0)); @@ -416,140 +422,141 @@ TEST_F(LiteralUtilTest, IsAllTuple) { // Verifies that CreateFromShape works for tuples. TEST_F(LiteralUtilTest, CreateFromShapeTuple) { - auto scalar = Literal::CreateR0(0.0); - auto matrix = Literal::CreateR2({{0, 0}, {0, 0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(0.0); + auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); auto x = Literal::CreateFromShape(tuple->shape()); EXPECT_EQ(*tuple, *x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}})->IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(Literal::CreateR2( + EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); - EXPECT_TRUE( - Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) + ->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) ->IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE( + LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = Literal::CreateR0(0.0f); - auto scalar_one = Literal::CreateR0(1.0f); + auto scalar_zero = LiteralUtil::CreateR0(0.0f); + auto scalar_one = LiteralUtil::CreateR0(1.0f); EXPECT_TRUE(scalar_zero->IsZero({})); EXPECT_FALSE(scalar_one->IsZero({})); - auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); EXPECT_FALSE(array->IsZero({0, 1})); EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); - auto complex_zero = Literal::CreateR0(0.0f); - auto complex_nonzero = Literal::CreateR0(0.5f); + auto complex_zero = LiteralUtil::CreateR0(0.0f); + auto complex_nonzero = LiteralUtil::CreateR0(0.5f); EXPECT_TRUE(complex_zero->IsZero({})); EXPECT_FALSE(complex_nonzero->IsZero({})); } @@ -563,7 +570,7 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. TypeParam half = TypeParam(1) / TypeParam(2); - auto data = Literal::CreateR2({{half, 2}, {3, 4}}); + auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); @@ -577,7 +584,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -585,13 +592,13 @@ TEST_F(LiteralUtilTest, ReshapeR0) { TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -605,13 +612,13 @@ TEST_F(LiteralUtilTest, ReshapeR4) { TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -623,7 +630,7 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Transpose(/*permutation=*/{}); EXPECT_EQ(*original, *reshape); } @@ -631,7 +638,7 @@ TEST_F(LiteralUtilTest, TransposeR0) { TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4({{ + auto original = LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -659,7 +666,7 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = Literal::CreateR2WithLayout( + auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->element_count(), 6); EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); @@ -670,7 +677,7 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) { ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). - auto mat_dim0major = Literal::CreateR2WithLayout( + auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->element_count(), 6); EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); @@ -695,8 +702,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -710,8 +717,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). - auto lit_dim0major = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); + auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->element_count(), 12); EXPECT_THAT(lit_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); @@ -723,28 +730,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = Literal::CreateR0(1); + auto input = LiteralUtil::CreateR0(1); auto result = input->Slice({}, {}); EXPECT_EQ(*input, *result); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); auto result = input->Slice({3}, {4}); - auto expected = Literal::CreateR1({4.0}); + auto expected = LiteralUtil::CreateR1({4.0}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto result = input_3x4->Slice({0, 2}, {2, 4}); - auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto input_2x3x2 = Literal::CreateR3( + auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); EXPECT_EQ(*input_2x3x2, *result); @@ -753,21 +760,21 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); - auto expected = Literal::CreateR1({77}); + auto expected = LiteralUtil::CreateR1({77}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } @@ -775,7 +782,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = - Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); EXPECT_EQ(output, *expected); } @@ -783,7 +790,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { Literal output(ShapeUtil::MakeShape(BF16, {})); bfloat16 h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -791,7 +798,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { Literal output(ShapeUtil::MakeShape(BF16, {3})); bfloat16 h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -799,28 +806,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); bfloat16 h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); - auto expected = Literal::CreateR0(2.5f); + auto expected = LiteralUtil::CreateR0(2.5f); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); - auto expected = Literal::CreateR1({-7, -7, -7}); + auto expected = LiteralUtil::CreateR1({-7, -7, -7}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); - auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); EXPECT_EQ(output, *expected); } @@ -828,7 +835,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateWithValue({4, 2}); auto expected = - Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); EXPECT_EQ(output, *expected); } @@ -836,7 +843,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -844,7 +851,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output(ShapeUtil::MakeShape(F16, {3})); half h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -852,15 +859,15 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output(ShapeUtil::MakeShape(F16, {2, 2})); half h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto output = input->Replicate(3); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); @@ -914,12 +921,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { } TEST_F(LiteralUtilTest, CopyFromScalars) { - auto zero = Literal::CreateR0(0); - auto nine = Literal::CreateR0(9); + auto zero = LiteralUtil::CreateR0(0); + auto nine = LiteralUtil::CreateR0(9); TF_EXPECT_OK(zero->CopyFrom(*nine)); EXPECT_EQ(*zero, *nine); - auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); EXPECT_EQ(zero->Get({}), 17); TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); @@ -928,13 +935,13 @@ TEST_F(LiteralUtilTest, CopyFromScalars) { TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0}); - const auto const_nine = Literal::CreateR1({9}); + const auto const_nine = LiteralUtil::CreateR1({9}); const auto const_empty = Literal::CreateFromShape(empty_r1_shape); { // Source contains dimension with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); EXPECT_EQ(*nine, *const_nine); @@ -943,7 +950,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { { // Copy 0 element to destination with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); EXPECT_EQ(*empty, *const_empty); @@ -958,16 +965,16 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { } TEST_F(LiteralUtilTest, CopyFromArrays) { - auto scalar_42 = Literal::CreateR0(42.0); - auto scalar_123 = Literal::CreateR0(123.0); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto scalar_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*scalar_42, *scalar_123); TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, /*src_shape_index=*/{})); EXPECT_EQ(*scalar_42, *scalar_123); EXPECT_EQ(scalar_42->Get({}), 123.0f); - auto matrix_1234 = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_5678 = Literal::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); + auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); EXPECT_NE(*matrix_1234, *matrix_5678); EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, @@ -977,19 +984,19 @@ TEST_F(LiteralUtilTest, CopyFromArrays) { } TEST_F(LiteralUtilTest, CopyFromTuples) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {matrix.get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get()}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = Literal::MakeTuple({Literal::CreateR0(-5).get(), - Literal::CreateR1({2.0, 4.0}).get(), - &nil_literal}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(-5).get(), + LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); @@ -1010,8 +1017,8 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(-2).get(), Literal::CreateR0(4).get()}); + auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), + LiteralUtil::CreateR0(4).get()}); EXPECT_EQ(tuple->Get({}, {0}), -2); EXPECT_EQ(tuple->Get({}, {1}), 4); @@ -1025,8 +1032,8 @@ TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto vector = Literal::CreateR1({5.0, 7.0}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({5.0, 7.0}); Status status = matrix->CopyFrom(*vector); ASSERT_FALSE(status.ok()); ASSERT_THAT(status.error_message(), @@ -1051,7 +1058,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); - auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); const char* d2 = reinterpret_cast(l2->data().data()); EXPECT_EQ(d2[0], 0); @@ -1150,12 +1157,12 @@ TEST_F(LiteralUtilTest, PopulateParallel) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = Literal::CreateR4WithLayout({{ + auto expected = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -1169,42 +1176,42 @@ TEST_F(LiteralUtilTest, ConvertR4) { TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = Literal::CreateR4WithLayout({{ + auto s8 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = Literal::CreateR4WithLayout({{ + auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = Literal::CreateR4WithLayout({{ + auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = Literal::CreateR4WithLayout({{ + auto s64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = Literal::CreateR4WithLayout({{ + auto u64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = Literal::CreateR4WithLayout({{ + auto pred = LiteralUtil::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = Literal::CreateR4WithLayout({{ + auto int32_pred = LiteralUtil::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f16 = Literal::CreateR4WithLayout({{ + auto f16 = LiteralUtil::CreateR4WithLayout({{ {{half(10.0), half(0.0), half(12.0), half(0.0)}, {half(0.0), half(15.0), half(0.0), half(17.0)}}, {{half(0.0), half(19.0), half(0.0), half(21.0)}, @@ -1212,7 +1219,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); - auto bf16 = Literal::CreateR4WithLayout({{ + auto bf16 = LiteralUtil::CreateR4WithLayout({{ {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, @@ -1220,17 +1227,17 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, }}, layout_r4_dim0major_); - auto f32 = Literal::CreateR4WithLayout({{ + auto f32 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = Literal::CreateR4WithLayout({{ + auto f64 = LiteralUtil::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); - auto c64 = Literal::CreateR4WithLayout({{ + auto c64 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, @@ -1302,18 +1309,18 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { } TEST_F(LiteralUtilTest, BitcastConvert) { - auto original = - Literal::CreateR1({tensorflow::bit_cast(2.5f), - tensorflow::bit_cast(-42.25f), - tensorflow::bit_cast(100.f), 0xbeef}); - auto expected = Literal::CreateR1( + auto original = LiteralUtil::CreateR1( + {tensorflow::bit_cast(2.5f), + tensorflow::bit_cast(-42.25f), + tensorflow::bit_cast(100.f), 0xbeef}); + auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, original->BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { - auto literal = Literal::CreateR0(1234); + auto literal = LiteralUtil::CreateR0(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), @@ -1348,7 +1355,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h1(1.0f); half h2(2.0f); - auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l = m.get(); EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); EXPECT_EQ(4, l->data().size()); @@ -1391,10 +1398,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { } TEST_F(LiteralUtilTest, LiteralSliceTest) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); @@ -1413,10 +1420,10 @@ TEST_F(LiteralUtilTest, LiteralSliceTest) { } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. const auto nested_tuple_view = LiteralSlice(*nested_tuple); @@ -1436,15 +1443,16 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { } TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); const auto nested_tuple_view = LiteralSlice(*nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); - EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + EXPECT_EQ(matrix_view, + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1488,7 +1496,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal literal(std::move(*matrix)); EXPECT_TRUE( @@ -1501,11 +1509,11 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( - {Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + auto nested_tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get(), &nil_literal}); @@ -1542,13 +1550,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*Literal::CreateR0(1.0))); - elements.push_back(std::move(*Literal::CreateR1({4, 8}))); - elements.push_back(std::move( - *Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get()}) + elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); + elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); + elements.push_back(std::move(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get()}) - )); + )); Literal literal = Literal::MoveIntoTuple(&elements); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1577,7 +1585,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); literal = std::move(*matrix); EXPECT_TRUE( @@ -1590,7 +1598,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); const auto matrix_view = LiteralSlice(*matrix); LiteralSlice matrix_view_copy(matrix_view); @@ -1601,9 +1609,9 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(42.0).get(), - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42.0).get(), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); @@ -1644,20 +1652,20 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { TEST_F(LiteralUtilTest, ProtoRoundTrip) { // Test serializing then deserializing a Literal through a proto. - auto one_f32 = Literal::CreateR0(1.0); - auto two_f32 = Literal::CreateR0(2.0); - auto vector_int8 = Literal::CreateR1({-128, 0, 2, 4, 7, 56, 127}); - auto vector_c64 = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - auto vector_bfloat16 = Literal::CreateR1( + auto one_f32 = LiteralUtil::CreateR0(1.0); + auto two_f32 = LiteralUtil::CreateR0(2.0); + auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = - Literal::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); auto matrix_pred = - Literal::CreateR2({{true, false, true}, {false, false, true}}); - auto tuple = Literal::MakeTuple( + LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); + auto tuple = LiteralUtil::MakeTuple( {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { @@ -1790,8 +1798,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { } TEST_F(LiteralUtilTest, SortSparseElements) { - auto literal = - Literal::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); + auto literal = LiteralUtil::CreateSparse({10, 10, 10}, + SparseIndexArray(10, 3), {}); literal->AppendSparseElement({2, 3, 4}, 2.0); literal->AppendSparseElement({3, 4, 5}, 3.0); literal->AppendSparseElement({1, 2, 3}, 1.0); @@ -1805,21 +1813,22 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); ASSERT_EQ( - Literal::CreateSparse(dimensions, indices, {true, false, true}) + LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) ->GetSparseElementAsString(1), "false"); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1, 2, 3}) + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, - {half{1.0}, half{2.0}, half{3.0}}) + ASSERT_EQ( + LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( - Literal::CreateSparse( + LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) ->GetSparseElementAsString(1), @@ -1827,33 +1836,36 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = Literal::CreateR0(9); + std::unique_ptr literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index eeabf835ac3..548fbe8a83a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -43,25 +43,6 @@ namespace xla { namespace { -constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; - -// Converts between little and big endian. -// -// Precondition: size % 2 == 0 (elements in the array are 16 bits long) -void ConvertEndianShort(string* bytes) { - CHECK_EQ(bytes->size() / 2, 0); - for (int64 i = 0; i < bytes->size(); i += 2) { - std::swap((*bytes)[i], (*bytes)[i + 1]); - } -} - -void ConvertEndianShort(char* bytes, int64 size) { - CHECK_EQ(size / 2, 0); - for (int64 i = 0; i < size; i += 2) { - std::swap(bytes[i], bytes[i + 1]); - } -} - // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template @@ -103,505 +84,54 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -LiteralBase::~LiteralBase() {} - -std::ostream& operator<<(std::ostream& out, const Literal& literal) { - out << literal.ToString(); - return out; -} - -Literal::StrideConfig::StrideConfig( - const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) - : dimensions(dimensions), - base(dimensions.size(), 0), - step(dimensions.size(), 1) { - if (!dimensions.empty()) { - // Selects the shape with the largest minor dimension as the one upon - // which to run the tight stride loop. - if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= - dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { - minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); - dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); - } else { - minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); - source_stride = - IndexUtil::GetDimensionStride(source_shape, minor_dimension); - } - minor_loop_size = dimensions[minor_dimension]; - step[minor_dimension] = minor_loop_size; - } -} - -Literal::Literal(const Shape& shape) - : Literal(shape, /*allocate_arrays=*/true) {} - -void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - SetPiece(subshape, &child_piece, allocate_arrays); - - piece->emplace_back(std::move(child_piece)); - } - } else if (ShapeUtil::IsArray(shape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(shape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(shape.layout()); - piece->set_buffer( - new char[max_sparse_elements * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); - piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); - } else { - piece->set_buffer(new char[piece->size_bytes()]); - } - } - } else { - // If the shape is neither an array nor tuple, then it must be - // zero-sized. Otherwise, some memory needs to be allocated for it. - CHECK_EQ(piece->size_bytes(), 0); - } -} - -Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(LayoutUtil::HasLayout(*shape_)); - root_piece_ = new Piece(); - root_piece_->set_subshape(shape_.get()); - CHECK(&root_piece_->subshape() == shape_.get()); - - SetPiece(*shape_, root_piece_, allocate_arrays); -} - -Literal::~Literal() { - if (root_piece_ != nullptr) { - DeallocateBuffers(); - delete root_piece_; - } -} - -void Literal::DeallocateBuffers() { - root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (piece->buffer() != nullptr) { - delete[] piece->buffer(); - delete piece->sparse_indices(); - } - }); -} - -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } - -Literal& Literal::operator=(Literal&& other) { - DCHECK(&other.root_piece_->subshape() == other.shape_.get()); - using std::swap; - swap(shape_, other.shape_); - swap(root_piece_, other.root_piece_); - DCHECK(&root_piece_->subshape() == shape_.get()); - - return *this; -} - -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); - literal->root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { - memset(piece->untyped_data(), 0, piece->size_bytes()); - } - }); - return literal; -} - -const SparseIndexArray* LiteralBase::sparse_indices( - const ShapeIndex& shape_index) const { - return piece(shape_index).sparse_indices(); -} - -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { - return piece(shape_index).sparse_indices(); -} - -/* static */ std::unique_ptr Literal::CreateFromDimensions( +/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { - return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); + return Literal::CreateFromShape( + ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr Literal::ConvertBF16ToF32( +/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr Literal::ConvertF32ToBF16( +/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -template -Status Literal::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); - - auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); - }; - - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { - // If any of the two shapes are scalars, we can just call the StridedCopy() - // directly, and we know we will be copying only one value. - TF_RET_CHECK(copy_size.empty()); - StridedCopy(data(), linear_index(shape(), dest_base), 0, - src_literal.data(), - linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::IsZeroElementArray(shape()) && - !ShapeUtil::IsZeroElementArray(src_literal.shape())) { - // Perform copy if neither src nor dest has dimensions with zero element, - // otherwise it's a no-op. - TF_RET_CHECK(src_base.size() == dest_base.size()); - TF_RET_CHECK(src_base.size() == copy_size.size()); - - // Scan the source from minor, stepping in copy size blocks, then within - // the index enumaration functor, do a strided copy advancing source index - // by one (walking through the minor dimension), and destination index by - // proper stride size at the matching dimension. - DimensionVector src_indexes(src_base.size(), 0); - DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); - - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { - // Map from multi-dimensional index, to source index. - std::transform(indexes.begin(), indexes.end(), src_base.begin(), - src_indexes.begin(), std::plus()); - // Map from multi-dimensional index, to destination index. - std::transform(indexes.begin(), indexes.end(), dest_base.begin(), - dest_indexes.begin(), std::plus()); - - int64 src_index = linear_index(src_literal.shape(), src_indexes); - int64 dest_index = linear_index(shape(), dest_indexes); - - // `this->` is needed to workaround MSVC bug: #16882 - StridedCopy(this->data(), dest_index, stride_config.dest_stride, - src_literal.data(), src_index, - stride_config.source_stride, stride_config.minor_loop_size); - return true; - }; - - ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, - stride_config.dimensions, stride_config.step, - copy_proc); - } - return Status::OK(); -} - -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { - DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); - const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( - src_literal.shape(), src_index); - const int64 dest_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - char* dest_address = - static_cast(untyped_data()) + dest_linear_index * primitive_size; - const char* source_address = - static_cast(src_literal.untyped_data()) + - src_linear_index * primitive_size; - if (dest_address != source_address) { - memcpy(dest_address, source_address, primitive_size); - } - return Status::OK(); -} - -/* static */ std::unique_ptr Literal::CreateToken() { +/* static */ std::unique_ptr LiteralUtil::CreateToken() { return MakeUnique(ShapeUtil::MakeTokenShape()); } -std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), - /*allocate_arrays=*/false)); - Literal& element = elements.back(); - element.root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* dest_piece) { - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece->set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece->set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - }); - } - // Set this literal to be nil-shaped. - *this = Literal(); - return elements; -} - -/* static */ Literal Literal::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { - std::vector element_shapes; - for (const Literal& element : elements) { - element_shapes.push_back(element.shape()); - } - Literal literal(ShapeUtil::MakeTupleShape(element_shapes), - /*allocate_arrays=*/false); - for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK( - literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); - } - return literal; -} - -namespace { - -// Copies the elements in 'src' to 'dest'. The shape and layout of the data in -// the array slices are indicated by dest_shape and src_shape respectively. -template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { - CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::IsZeroElementArray(dest_shape)) { - return; - } - std::vector index(ShapeUtil::Rank(dest_shape)); - do { - dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = - src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); -} - -} // namespace - -Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { - CHECK(subshape_ != nullptr); - CHECK(src.subshape_ != nullptr); - if (ShapeUtil::Equal(subshape(), src.subshape())) { - // If the layouts are equal it's faster just to memcpy. - memcpy(buffer(), src.buffer(), src.size_bytes()); - } else { - TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); - switch (subshape().element_type()) { -#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ - case (XLA_T): \ - CopyElementsBetween(data(), src.data(), \ - subshape(), src.subshape()); \ - break; - COPY_ELEMENTS(U8, uint8); - COPY_ELEMENTS(U16, uint16); - COPY_ELEMENTS(U32, uint32); - COPY_ELEMENTS(U64, uint64); - COPY_ELEMENTS(S8, int8); - COPY_ELEMENTS(S16, int16); - COPY_ELEMENTS(S32, int32); - COPY_ELEMENTS(S64, int64); - COPY_ELEMENTS(F16, half); - COPY_ELEMENTS(BF16, bfloat16); - COPY_ELEMENTS(F32, float); - COPY_ELEMENTS(F64, double); - COPY_ELEMENTS(C64, complex64); - COPY_ELEMENTS(PRED, bool); -#undef COPY_ELEMENTS - default: - return Unimplemented( - "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); - } - } - return Status::OK(); -} - -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - const Shape& src_subshape = - ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); - if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { - return InvalidArgument( - "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); - } - return root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { - return Status::OK(); - } - - // Determine if this index is in the part of this literal that we want - // to copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - return Status::OK(); - } - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); - return Status::OK(); - }); -} - -Status Literal::MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { - return InvalidArgument( - "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); - } - - src_literal.root_piece_->ForEachSubpiece( - [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { - return; - } - - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - }); - - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); - delete src_literal.root_piece_; - src_literal.root_piece_ = new LiteralBase::Piece(); - src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - - return Status::OK(); -} - -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) - << ShapeUtil::HumanString(src_literal.shape()); - TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); - - switch (shape().element_type()) { - case U8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case BF16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case C64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case PRED: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - default: - break; - } - return Unimplemented( - "Copying a slice from a Literal object with element type %d is not " - "implemented.", - shape().element_type()); -} - -/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F16: - return std::move(*Literal::CreateR0(static_cast(0.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(0.0f))); + *LiteralUtil::CreateR0(static_cast(0.0f))); case F32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case C64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -614,33 +144,33 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::One(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F16: - return std::move(*Literal::CreateR0(static_cast(1.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(1.0f))); + *LiteralUtil::CreateR0(static_cast(1.0f))); case F32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case C64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -653,44 +183,44 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case F32: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; @@ -701,42 +231,42 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case F32: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; @@ -747,7 +277,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); @@ -755,17 +285,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, return literal; } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(element_count(), values.bits()); - CHECK_EQ(shape().element_type(), PRED); - for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set({i}, values.get(i)); - } -} - -/* static */ std::unique_ptr Literal::CreateR1U8( +/* static */ std::unique_ptr LiteralUtil::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); @@ -775,116 +295,13 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return literal; } -/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, - float to, - int64 rows, - int64 cols) { +/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( + float from, float to, int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { - // Create new shape with 'new_layout' set at the given shape index. - Shape new_shape = shape(); - Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); - TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); - *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { - CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) - << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) - << " not compatible with literal shape " - << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); - ShapeUtil::ForEachSubshape( - result->shape(), - [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); - } - }); - return result; -} - -StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Broadcast only supports arrays."); - } - - for (int64 i = 0; i < dimensions.size(); i++) { - TF_RET_CHECK(shape().dimensions(i) == - result_shape.dimensions(dimensions[i])); - } - - std::unique_ptr result = MakeUnique(result_shape); - - // scratch_source_index is temporary storage space for the computed index into - // the input literal. We put it here to avoid allocating an std::vector in - // every iteration of ShapeUtil::ForEachIndex. - std::vector scratch_source_index(shape().dimensions_size()); - - char* dest_data = static_cast(result->untyped_data()); - const char* source_data = static_cast(untyped_data()); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { - for (int64 i = 0; i < dimensions.size(); ++i) { - scratch_source_index[i] = output_index[dimensions[i]]; - } - int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( - result_shape, output_index); - int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( - shape(), scratch_source_index); - memcpy(dest_data + primitive_size * dest_index, - source_data + primitive_size * source_index, primitive_size); - return true; - }); - - return std::move(result); -} - -StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Reshape does not support tuples."); - } - std::unique_ptr output; - if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); - } else { - output = CloneToUnique(); - } - // Because the layout is monotonic, we can simply reuse the same sequence of - // values without changing their order. - *output->mutable_shape_do_not_use() = - ShapeUtil::MakeShape(shape().element_type(), dimensions); - - int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); - if (elements_before != elements_after) { - return InvalidArgument( - "Shapes before and after Literal::Reshape have different numbers " - "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); - } - return std::move(output); -} - -/* static */ std::unique_ptr Literal::ReshapeSlice( +/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( tensorflow::gtl::ArraySlice new_dimensions, tensorflow::gtl::ArraySlice minor_to_major, const LiteralSlice& literal) { @@ -956,575 +373,64 @@ StatusOr> LiteralBase::Reshape( return new_literal; } -std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) - << "Given permutation is not a permutation of dimension numbers"; - // To transpose the array, we just permute the dimensions and layout, and - // do a straight memory copy of the raw data set. - // This is considerably faster than iterating over every array element using - // the EachCell<>() and Set<>() APIs. - std::vector inverse_permutation = InversePermutation(permutation); - Shape permuted_shape = - ShapeUtil::PermuteDimensions(inverse_permutation, shape()); - // Replace the layout with one affine to this shape, such that a - // transpose operation can be performed by leaving the flat values - // representation intact. - // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. - // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the - // most minor. - // - // Essentially, given MinMaj(Di) the position of the Di dimension within the - // minor to major vector, and given T(Di) the index that the original Di - // dimension has within the transposed array, a layout is affine if - // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major - // vector of the affine layout. - CHECK(LayoutUtil::IsDenseArray(permuted_shape)); - Layout* layout = permuted_shape.mutable_layout(); - layout->clear_minor_to_major(); - for (auto index : LayoutUtil::MinorToMajor(shape())) { - layout->add_minor_to_major(inverse_permutation[index]); - } - auto new_literal = MakeUnique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), - ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); - return new_literal; -} - -template -std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - NativeT value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; -} - -std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; - - DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { - CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) - << "dnum = " << dnum; - int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GE(dimension, 0) << "dnum = " << dnum; - result_dimensions.push_back(dimension); - } - const auto result_shape = - ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, - LayoutUtil::MinorToMajor(shape())); - switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); - case BF16: - return SliceInternal(result_shape, start_indices); - case C64: - return SliceInternal(result_shape, start_indices); - case S32: - return SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(result_shape.element_type()); - } -} - -Literal LiteralBase::Clone() const { - Literal result(shape()); - TF_CHECK_OK(result.CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsDenseArray(subshape)); - switch (subshape.element_type()) { +/* static */ Literal LiteralUtil::GetFirstScalarLiteral( + const LiteralSlice& literal) { + CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); + switch (literal.shape().element_type()) { case PRED: - return Get(multi_index, shape_index) ? "true" : "false"; - case S8: - return StrCat(Get(multi_index, shape_index)); - case S16: - return StrCat(Get(multi_index, shape_index)); - case S32: - return StrCat(Get(multi_index, shape_index)); - case S64: - return StrCat(Get(multi_index, shape_index)); - case U8: - return StrCat(Get(multi_index, shape_index)); - case U16: - return StrCat(Get(multi_index, shape_index)); - case U32: - return StrCat(Get(multi_index, shape_index)); - case U64: - return StrCat(Get(multi_index, shape_index)); - case F16: - return StrCat(static_cast(Get(multi_index, shape_index))); - case F32: - return StrCat(Get(multi_index, shape_index)); - case BF16: - return StrCat( - static_cast(Get(multi_index, shape_index))); - case F64: - return StrCat(Get(multi_index, shape_index)); - case C64: { - complex64 c = Get(multi_index, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); - } -} - -string LiteralBase::GetSparseElementAsString( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsSparseArray(subshape)); - switch (subshape.element_type()) { - case PRED: - return GetSparseElement(sparse_element_number, shape_index) - ? "true" - : "false"; - case S8: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); - case S16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U8: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case F16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case BF16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case C64: { - complex64 c = - GetSparseElement(sparse_element_number, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << "Invalid element type for sparse arrays: " - << PrimitiveType_Name(subshape.element_type()); - } -} - -StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - return Get(multi_index); - case U8: - return Get(multi_index); - case S32: - return Get(multi_index); - case S64: - return Get(multi_index); - case U32: - return Get(multi_index); - case U64: - return Get(multi_index); - default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); - } -} - -size_t LiteralBase::Hash() const { - using tensorflow::Hash64; - using tensorflow::Hash64Combine; - - size_t hash_value = ShapeUtil::Hash(shape()); - - ShapeUtil::ForEachSubshape( - shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsArray(subshape)) { - return; - } - - CHECK(LayoutUtil::IsDense(subshape.layout())); - hash_value = Hash64Combine( - hash_value, Hash64(static_cast(untyped_data(index)), - size_bytes(index))); - }); - - return hash_value; -} - -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - Set(multi_index, value); - break; - case U8: - Set(multi_index, value); - break; - case S32: - Set(multi_index, value); - break; - case S64: - Set(multi_index, value); - break; - case U32: - Set(multi_index, value); - break; - case U64: - Set(multi_index, value); - break; - default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); - } - return Status::OK(); -} - -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Piece& p = piece(shape_index); - CHECK_GE(sparse_element_number, 0); - CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); - return p.sparse_indices()->At(sparse_element_number); -} - -void Literal::SortSparseElements(const ShapeIndex& shape_index) { - piece(shape_index).SortSparseElements(); -} - -Literal LiteralBase::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); - switch (shape().element_type()) { - case PRED: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); // 8 bit types. case S8: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U8: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); // 16 bit types. case BF16: - return std::move( - *Literal::CreateR0(GetFirstElement())); + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); case F16: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S16: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U16: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); // 32 bit types. case F32: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S32: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U32: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); // 64 bit types. case C64: + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); + case F64: return std::move( - *Literal::CreateR0(GetFirstElement())); - case F64: - return std::move(*Literal::CreateR0(GetFirstElement())); + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S64: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U64: - return std::move(*Literal::CreateR0(GetFirstElement())); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + LOG(FATAL) << "Unhandled primitive type " + << literal.shape().element_type(); } } -void LiteralBase::Piece::SortSparseElements() { - switch (subshape().element_type()) { - case PRED: - SortSparseElementsInternal(); - break; - case S8: - SortSparseElementsInternal(); - break; - case U8: - SortSparseElementsInternal(); - break; - case S16: - SortSparseElementsInternal(); - break; - case U16: - SortSparseElementsInternal(); - break; - case S32: - SortSparseElementsInternal(); - break; - case U32: - SortSparseElementsInternal(); - break; - case S64: - SortSparseElementsInternal(); - break; - case U64: - SortSparseElementsInternal(); - break; - case F32: - SortSparseElementsInternal(); - break; - case F64: - SortSparseElementsInternal(); - break; - case C64: - SortSparseElementsInternal(); - break; - case F16: - SortSparseElementsInternal(); - break; - case BF16: - SortSparseElementsInternal(); - break; - default: - LOG(FATAL) << "Element type not valid for sparse array: " - << PrimitiveType_Name(subshape().element_type()); - } -} - -template -void LiteralBase::Piece::SortSparseElementsInternal() { - CHECK(LayoutUtil::IsSparseArray(subshape())); - int64 num_elements = sparse_indices()->index_count(); - auto values = data(); - CHECK_LE(num_elements, values.size()); - sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); -} - -namespace { - -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); - - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; - - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); - } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { - pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); - } - pieces->push_back("}"); - return; - } - - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); - }; - - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); - } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); - } -} - -} // namespace - -int64 LiteralBase::sparse_element_count() const { - CHECK(LayoutUtil::IsSparseArray(shape())); - return sparse_indices()->index_count(); -} - -string LiteralBase::ToString(bool print_layout) const { - std::vector pieces; - CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); -} - -/* static */ std::unique_ptr Literal::MakeTuple( +/* static */ std::unique_ptr LiteralUtil::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; for (const auto* element : elements) { @@ -1537,7 +443,7 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -/* static */ std::unique_ptr Literal::MakeTupleFromSlices( +/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; for (const auto& element : elements) { @@ -1550,7 +456,7 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -/* static */ std::unique_ptr Literal::MakeTupleOwned( +/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); @@ -1565,822 +471,9 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -void LiteralBase::EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const { - if (ShapeUtil::IsZeroElementArray(shape())) { - return; - } - std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - shape(), /*linear_index=*/0); - do { - per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -namespace { -template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( - src_literal.shape(), - primitive_util::NativeToPrimitiveType())); - auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); - int64 num_elements = src_literal.element_count(); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = converter(src_data[i]); - } - return result_literal; -} - -template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { return static_cast(src); }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { - return tensorflow::bit_cast(src); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. -template -typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - LOG(FATAL) << "Invalid bitcast between types of different sizes."; -} - -template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - -template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - if (bitcast) { - return BitcastBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } else { - return ConvertBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } -} - -template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal, \ - bitcast); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F16) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) - CONVERT_IF_TYPES_MATCH(BF16) -#undef CONVERT_IF_TYPES_MATCH - case C64: - if (!bitcast) { - return ConvertToC64(src_literal); - } - break; - // Other types are not yet supported. - default: - break; - } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); -} - -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); - if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); - } - switch (literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ - bitcast); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F16) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) - CONVERT_IF_DEST_TYPE_MATCHES(BF16) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); - } -} - -} // namespace - -StatusOr> LiteralBase::Convert( - PrimitiveType primitive_dest_type) const { - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); -} - -StatusOr> LiteralBase::BitcastConvert( - PrimitiveType primitive_dest_type) const { - if (primitive_util::BitWidth(shape().element_type()) != - primitive_util::BitWidth(primitive_dest_type)) { - return InvalidArgument( - "Cannot bitcast convert from %s to %s, bit widths are different: %d != " - "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), - primitive_util::BitWidth(shape().element_type()), - primitive_util::BitWidth(primitive_dest_type)); - } - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); -} - -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { - if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } - return Convert(dest_shape.element_type()); - } - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralSlice(*this, {i}); - TF_ASSIGN_OR_RETURN( - auto new_element, - element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); - } - auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); - return std::move(converted); -} - -template -bool LiteralBase::Piece::EqualElementsInternal( - const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { - return (Get(*multi_index) == other.Get(*multi_index)); - } - for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { - multi_index->push_back(i); - if (!EqualElementsInternal(other, multi_index)) { - return false; - } - multi_index->pop_back(); - } - return true; -} - -bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { - DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); - - std::vector multi_index; - switch (subshape().element_type()) { - case PRED: - return EqualElementsInternal(other, &multi_index); - case U8: - return EqualElementsInternal(other, &multi_index); - case S32: - return EqualElementsInternal(other, &multi_index); - case S64: - return EqualElementsInternal(other, &multi_index); - case U32: - return EqualElementsInternal(other, &multi_index); - case U64: - return EqualElementsInternal(other, &multi_index); - case F32: - return EqualElementsInternal(other, &multi_index); - case F64: - return EqualElementsInternal(other, &multi_index); - case F16: - return EqualElementsInternal(other, &multi_index); - case BF16: - return EqualElementsInternal(other, &multi_index); - case C64: - return EqualElementsInternal(other, &multi_index); - default: - LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " - << PrimitiveType_Name(subshape().element_type()); - } -} - -bool LiteralBase::operator==(const LiteralBase& other) const { - if (!ShapeUtil::Compatible(shape(), other.shape())) { - return false; - } - - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - return true; - }); -} - -namespace { - -template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, - NativeT value) { - for (int64 i = 0; i < data.size(); ++i) { - if (data[i] != value) { - return false; - } - } - return true; -} - -} // namespace - -bool LiteralBase::IsAll(int8 value) const { - return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, - const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case U8: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U32: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U64: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case S8: - return AllElementsEqualValue(piece.data(), value); - case S32: - return AllElementsEqualValue(piece.data(), value); - case S64: - return AllElementsEqualValue(piece.data(), value); - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case PRED: - if (value == 0) { - return AllElementsEqualValue(piece.data(), false); - } - if (value == 1) { - return AllElementsEqualValue(piece.data(), true); - } - return false; - default: - return false; - } - return false; - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllFloat(float value) const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllComplex(complex64 value) const { - switch (shape().element_type()) { - case C64: - return AllElementsEqualValue(root_piece().data(), - value); - default: - return false; - } -} - -bool LiteralBase::IsAllFirst() const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::IsZeroElementArray(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: - return false; - } - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { - CHECK(ShapeUtil::IsArray(shape())); - switch (shape().element_type()) { - case U8: - return Get(indices) == 0; - case U32: - return Get(indices) == 0; - case U64: - return Get(indices) == 0; - case S8: - return Get(indices) == 0; - case S32: - return Get(indices) == 0; - case S64: - return Get(indices) == 0; - case F32: - return Get(indices) == 0.0f; - case F64: - return Get(indices) == 0.0; - case C64: - return Get(indices) == complex64(0.0f, 0.0f); - case F16: - return Get(indices) == static_cast(0.0f); - case BF16: - return Get(indices) == static_cast(0.0f); - case PRED: - return Get(indices) == false; - default: - LOG(FATAL) << "Input literal must be an array."; - } -} - -namespace { - -template -void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { - *dest = RepeatedFieldT(src.begin(), src.end()); -} - -} // namespace - -void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); - switch (subshape().element_type()) { - case PRED: - CopyToRepeatedField(proto->mutable_preds(), data()); - break; - case U8: - proto->set_u8s(static_cast(data().data()), - element_count()); - break; - case U32: - CopyToRepeatedField(proto->mutable_u32s(), data()); - break; - case U64: - CopyToRepeatedField(proto->mutable_u64s(), data()); - break; - case S32: - CopyToRepeatedField(proto->mutable_s32s(), data()); - break; - case S64: - CopyToRepeatedField(proto->mutable_s64s(), data()); - break; - case F16: - *proto->mutable_f16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_f16s()); - } - break; - case BF16: - *proto->mutable_bf16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_bf16s()); - } - break; - case F32: - CopyToRepeatedField(proto->mutable_f32s(), data()); - break; - case F64: - CopyToRepeatedField(proto->mutable_f64s(), data()); - break; - case C64: - for (complex64 value : data()) { - proto->add_c64s(value.real()); - proto->add_c64s(value.imag()); - } - break; - case TUPLE: - case TOKEN: - // Nothing to do but assign the shape which is done above. - return; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } -} - -const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -namespace { - -template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, - const RepeatedFieldT& src) { - if (dest.size() != src.size()) { - return InvalidArgument( - "Expected %lu elements in LiteralProto repeated field, has %d", - dest.size(), src.size()); - } - std::copy(src.begin(), src.end(), dest.begin()); - return Status::OK(); -} - -} // namespace - -Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. - TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); - - switch (subshape().element_type()) { - case PRED: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); - break; - case U8: { - auto u8_data = data(); - TF_RET_CHECK(proto.u8s().size() == u8_data.size()); - std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); - } break; - case S32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); - break; - case S64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); - break; - case U32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); - break; - case U64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); - break; - case F16: { - const string& s(proto.f16s()); - TF_RET_CHECK(data().size() * sizeof(half) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - - case BF16: { - const string& s(proto.bf16s()); - TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - case F32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); - break; - case F64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); - break; - case C64: { - auto complex_data = data(); - TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); - for (int64 i = 0; i < complex_data.size(); ++i) { - complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; - } - } break; - case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } - return Status::OK(); -} - -LiteralProto LiteralBase::ToProto() const { - LiteralProto proto; - root_piece().ForEachSubpiece( - [&](const ShapeIndex& index, const Piece& piece) { - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - }); - - if (LayoutUtil::IsSparseArray(shape())) { - CopyToRepeatedField(proto.mutable_sparse_indices(), - sparse_indices()->data()); - } - - return proto; -} - -/* static */ -StatusOr> Literal::CreateFromProto( - const LiteralProto& proto) { - if (!proto.has_shape()) { - return InvalidArgument("LiteralProto has no shape"); - } - if (!LayoutUtil::HasLayout(proto.shape())) { - return InvalidArgument("LiteralProto has no layout"); - } - - auto literal = MakeUnique(proto.shape()); - - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } - - if (ShapeUtil::IsTuple(piece->subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece->subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece->subshape()), - proto_element->tuple_literals_size()); - } - return Status::OK(); - } - if (piece->subshape().element_type() == TOKEN) { - return Status::OK(); - } - - CHECK(ShapeUtil::IsArray(piece->subshape())); - TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); - - return Status::OK(); - })); - - return std::move(literal); -} - -/* static */ string Literal::MultiIndexAsString( +/* static */ string LiteralUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); } -const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { - return piece(shape_index).untyped_data(); -} - -void* Literal::untyped_data(const ShapeIndex& shape_index) { - return piece(shape_index).untyped_data(); -} - -int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { - return piece(shape_index).size_bytes(); -} - -string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(shape().element_type(), U8); - return string(tensorflow::bit_cast(data().data()), - ShapeUtil::ElementsIn(shape())); -} - -void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - if (ShapeUtil::IsTuple(subshape)) { - BuildPieceSubtree(subshape, &child_piece); - } - - piece->emplace_back(std::move(child_piece)); - } -} - -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - -BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); - CHECK(LayoutUtil::HasLayout(*shape_)); - - root_piece_ = Piece(); - root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(shape_.get()); -} - -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsTuple(*shape_)); - CHECK(!ShapeUtil::IsNestedTuple(*shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); - root_piece_ = Piece(); - root_piece_.set_subshape(shape_.get()); - BuildPieceSubtree(*shape_, &root_piece_); - - for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); - root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); - } -} - } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 37ca8ea9f1d..e3737a9d005 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -51,679 +52,12 @@ limitations under the License. namespace xla { -// Forward declare Literal and LiteralSlice class to be used by the creation -// methods in the base class. -class Literal; -class LiteralSlice; - -// Abstract base class for literals. -class LiteralBase { +class LiteralUtil { public: - virtual ~LiteralBase() = 0; - - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const LiteralBase& other) const; - bool operator!=(const LiteralBase& other) const { return !(*this == other); } - - // Returns the shape of the literal. - const Shape& shape() const { return root_piece().subshape(); } - - // Serialize to proto. - LiteralProto ToProto() const; - - // Returns an ArraySlice of the array for this literal for the given NativeT - // (e.g., float). CHECKs if the subshape of the literal at the given - // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type - // to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to the sparse index array. Returns nullptr if the - // literal is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to (or size of) the underlying buffer holding the - // array at the given shape index. CHECKs if the subshape of the literal at - // the given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - int64 size_bytes(const ShapeIndex& shape_index = {}) const; - - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Gets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - // Overloads of Get for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; - - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; - - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; - - // Returns the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); - } - - // Returns the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + LiteralUtil() = delete; // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // TODO(b/67651157): The methods below which perform computation on Literals - // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with - // evaluator code which operates on Literals. - // - // Creates a new value that has the equivalent value as this - // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, - // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by broadcasting this literal with `dimensions` to - // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - // - // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); - - protected: - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Returns the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Returns the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Returns the child piece at 'index' of this piece. - Piece& child(int64 index) { return children_[index]; } - - // Adds a child piece to this piece's children. - void emplace_back(Piece child_piece) { - children_.emplace_back(std::move(child_piece)); - } - - // Returns the size of children pieces of this piece. - int64 children_size() { return children_.size(); } - - // Visitor functions that recursively traverses the piece and calls the - // given function at each child piece. The function has the type: - // void (const ShapeIndex& index, const Piece& piece) - template - void ForEachSubpiece(const Fn& func) const { - ShapeIndex index; - return ForEachHelper( - [&func](const ShapeIndex& index, const Piece& piece) { - func(index, piece); - return Status::OK(); - }, - *this, &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, const Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachSubpieceWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, *this, &index); - } - // Same as above, but the function has the type: - // Bool (const ShapeIndex& index, const Piece& piece) - // The first non-true return value is returned by the function. - template - bool ForEachSubpieceWithBool(const Fn& func) const { - ShapeIndex index; - return ForEachHelperBool(func, *this, &index); - } - // Same as above, but the function has the type: - // Void (const ShapeIndex& index, Piece& piece) - template - void ForEachMutableSubpiece(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - [&func](const ShapeIndex& index, Piece* piece) { - func(index, piece); - return Status::OK(); - }, - const_cast(this), &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachMutableSubpieceWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - func, const_cast(this), &index); - } - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. - // The first non-OK (or non-true) value is returned by the function. - // The callable 'func' has the same signature as described above in - // ForEachSubpiece*. - template - Status ForEachHelper(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - template - bool ForEachHelperBool(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - if (!func(*index, piece)) { - return false; - } - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - if (!ForEachHelperBool(func, piece.children_[i], index)) { - return false; - } - index->pop_back(); - } - return true; - } - template - Status ForEachMutableHelper(const Fn& func, Piece* piece, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece->children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, &piece->children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - - // Children pieces for tuple shaped pieces. - std::vector children_ = {}; - }; // class Piece - - const Piece& piece(const ShapeIndex& shape_index) const { - Piece* piece = &const_cast(root_piece()); - for (const auto i : shape_index) { - DCHECK_GE(i, 0); - DCHECK_LT(i, piece->children_size()); - piece = &piece->child(i); - } - return *piece; - } - - // Returns the piece at the root of the shape. - virtual const Piece& root_piece() const = 0; - - // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; - friend class LiteralSlice; - friend class BorrowingLiteral; - - private: - template - std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; -}; - -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { - public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } - - // Returns a MutableArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::data; - - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); - - // Returns a pointer to the underlying buffer holding the array at the given - // shape index. CHECKs if the subshape of the literal at the given ShapeIndex - // is not array. - void* untyped_data(const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::untyped_data; - - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Sets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - // Overloads of Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); - - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); - - // Factory methods below. - // - - // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static Literal GetFirstScalarLiteral(const LiteralSlice& literal); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -889,7 +223,7 @@ class Literal : public LiteralBase { // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; - // auto result = Literal::MakeTupleOwned(std::move(elements)); + // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. @@ -899,7 +233,7 @@ class Literal : public LiteralBase { // This overload lets you pass a braced list of unique_ptrs to // MakeTupleOwned: // - // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // // Simply relying on the MakeTupleOwned(std::vector>) // overload doesn't work because std::initializer_list's elements are always @@ -920,19 +254,6 @@ class Literal : public LiteralBase { // Create a constant token literal. Token types have no value. static std::unique_ptr CreateToken(); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive @@ -1000,194 +321,12 @@ class Literal : public LiteralBase { // dimension 1 equal to 8. static string MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index); - - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return const_cast(LiteralBase::piece(shape_index)); - } - - Piece& root_piece() const override { return *root_piece_; }; - - // Internal template helper for the Literal::CopySliceFrom(), matching its - // arguments one by one. - template - Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Utility structure which is used to create the optimal configuration for - // a ShapeUtil::ForEachIndex() scan across two literals. - struct StrideConfig { - StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); - - // The dimensions of the stride operation. Essentially every dimension - // will be iterated from base[i] to base[i]+dimensions[i], in step[i] - // steps. - tensorflow::gtl::ArraySlice dimensions; - DimensionVector base; - DimensionVector step; - int64 minor_dimension = 0; - // The size of the strides for source and destination. One of the two - // (the one looping through its most minor dimension) will be 1, while - // the other will be the stride size at the dimension matching the other - // shape most minor dimension being scanned. - int64 dest_stride = 1; - int64 source_stride = 1; - // The size of the inner loop on the most minor dimension. - int64 minor_loop_size = 1; - }; - - // Literal class always owns the shape. The parent class borrows this shape. - std::unique_ptr shape_; - - Piece* root_piece_ = nullptr; - - // Implementation details shared between Populate() and PopulateParallel() - template - Status PopulateInternal(const FnType& generator, bool parallel); - - // Deallocate the buffers held by this literal. - void DeallocateBuffers(); - - friend class LiteralBase; }; + std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralSlice contains pointers to shape and -// literal buffers always owned by others. -class LiteralSlice : public LiteralBase { - public: - LiteralSlice() : LiteralBase() {} - - // Implicit conversion constructors. - LiteralSlice(const LiteralBase& literal); - LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); - - private: - const Piece& root_piece() const override { return *root_piece_; }; - - const Piece* root_piece_; // Not owned. -}; - -// A read-only Literal where the underlying buffers are never owned by this -// class. -class BorrowingLiteral : public LiteralBase { - public: - BorrowingLiteral() : LiteralBase() {} - - // 'src_buf_ptr' is not owned by this class and must outlive the - // lifetime of this class. It points to an appropirately sized buffer with - // data interpretered as indicated by 'shape'. - // This constructor is only used for array shapes. - BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); - // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, - const Shape& shape); - // TODO(b/79707221): adding constructors for nested tuples as well. - - private: - // Recursively builds the subtree for the given piece and sets the subshapes - // of the given piece with the given shape. - void BuildPieceSubtree(const Shape& shape, Piece* piece); - - // Accessor for the root piece of this literal. - const Piece& root_piece() const override { return root_piece_; }; - Piece root_piece_; - - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. - std::unique_ptr shape_; -}; - template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(subshape())); - return data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)]; -} - -template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - CHECK(LayoutUtil::IsDenseArray(subshape())); - data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)] = value; -} - -template -tensorflow::gtl::ArraySlice LiteralBase::data( - const ShapeIndex& shape_index) const { - return piece(shape_index).data(); -} - -template -tensorflow::gtl::MutableArraySlice Literal::data( - const ShapeIndex& shape_index) { - return piece(shape_index).data(); -} - -template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - return piece(shape_index).Get(multi_index); -} - -template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { - return root_piece().Get(multi_index); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { - return piece(shape_index).Set(multi_index, value); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - return root_piece().Set(multi_index, value); -} - -template -/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { +/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); @@ -1195,7 +334,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), @@ -1205,7 +344,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( @@ -1218,13 +357,13 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2( +/* static */ std::unique_ptr LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr Literal::CreateR3WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -1249,14 +388,14 @@ template } template -/* static */ std::unique_ptr Literal::CreateR3( +/* static */ std::unique_ptr LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr Literal::CreateR4WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -1287,7 +426,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateSparse( +/* static */ std::unique_ptr LiteralUtil::CreateSparse( tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort) { int64 num_elements = values.size(); @@ -1302,7 +441,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4( +/* static */ std::unique_ptr LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -1310,7 +449,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), @@ -1320,38 +459,40 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArray( +/* static */ std::unique_ptr LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2D( +/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3D( +/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -1376,7 +517,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -1404,49 +545,21 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4D( +/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } -template -NativeT LiteralBase::GetFirstElement() const { - return data().at(0); -} - -template -NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { - CHECK( - LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); - return data(shape_index)[sparse_element_number]; -} - -template -void Literal::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, - const ShapeIndex& shape_index) { - Piece& p = piece(shape_index); - const Shape& subshape = p.subshape(); - CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); - CHECK_EQ(multi_index.size(), rank); - int64 last_element = p.sparse_indices()->index_count(); - CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); - p.sparse_indices()->Append(multi_index); - CHECK_LT(last_element, p.data().size()); - p.data()[last_element] = value; -} - // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -1455,174 +568,8 @@ template } template -void LiteralBase::EachCell( - std::function indices, - NativeT value)> - per_cell) const { - if (ShapeUtil::IsZeroElementArray(shape())) { - return; - } - std::vector indices(ShapeUtil::Rank(shape()), 0); - do { - per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (int64 i = 0; i < values.size(); ++i) { - Set({i}, values[i]); - } -} - -template -void Literal::PopulateR2( - std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - - const int64 dim0_size = values.size(); - const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - Set({dim0, dim1}, value); - ++dim1; - } - CHECK_EQ(dim1_size, dim1); - ++dim0; - } -} - -template -void Literal::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); - for (int dim = 0; dim < values.num_dimensions(); ++dim) { - CHECK_EQ(values.dim(dim), shape().dimensions(dim)); - } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); -} - -template -void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { - CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); - CHECK_EQ(indices.rank(), rank); - int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); - CHECK_LE(indices.max_indices(), max_elements); - int64 num_elements = values.size(); - CHECK_LE(num_elements, max_elements); - CHECK_EQ(num_elements, indices.index_count()); - auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices - // in the SparseIndexArray. So there is no need to adjust the size of the data - // here. It is enough to just copy the incoming values into the data buffer. - std::copy(values.begin(), values.end(), root_data.begin()); - *this->root_piece().sparse_indices() = std::move(indices); - if (sort) { - auto root_data = this->root_piece().data(); - this->root_piece().sparse_indices()->SortWithValues(root_data); - } - DCHECK(this->root_piece().sparse_indices()->Validate(shape())); -} - -template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { - const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); - TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); - TF_RET_CHECK(this_shape.element_type() == - primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); - if (rank > 0) { - StrideConfig stride_config(this_shape, this_shape, - AsInt64Slice(this_shape.dimensions())); - int64 minor_dimension_size = - ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { - DimensionVector minor_scan_indexes(rank, 0); - const int64 index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); - std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); - for (int64 i = 0; i < minor_dimension_size; ++i) { - minor_scan_indexes[stride_config.minor_dimension] = i; - literal_data.at(index + i) = generator(minor_scan_indexes); - } - }; - if (parallel) { - ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, - stride_config.dimensions, - stride_config.step, init_function); - } else { - ShapeUtil::ForEachIndex( - this_shape, stride_config.base, stride_config.dimensions, - stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { - init_function(indexes); - return true; - }); - } - } else { - // For scalars. - literal_data.at(0) = generator({}); - } - return Status::OK(); -} -template -Status Literal::Populate(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/false); -} - -template -Status Literal::PopulateParallel(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/true); -} - -template -void Literal::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (NativeT& element : data()) { - element = value; - } -} - -template -/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( +/* static */ std::unique_ptr +LiteralUtil::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions)); @@ -1630,44 +577,9 @@ template return literal; } -template -std::unique_ptr LiteralBase::Replicate(int64 times) const { - DimensionVector bounds = {times}; - bounds.reserve(shape().dimensions_size() + 1); - for (int64 bound : shape().dimensions()) { - bounds.push_back(bound); - } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); - if (elements == 0) { - return literal; - } - - DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; - input_indices.remove_prefix(1); - - bool done = false; - while (!done) { - const auto element = Get(input_indices); - literal->Set(output_indices, element); - - done = true; - for (int n = 0; n < output_indices.size(); ++n) { - ++output_indices[n]; - if (output_indices[n] < bounds[n]) { - done = false; - break; - } - output_indices[n] = 0; - } - } - return literal; -} - template -/* static */ StatusOr> Literal::CreateRandomLiteral( +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -1681,8 +593,9 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -1692,8 +605,8 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 857aae0a798..6b7fd10d63f 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 45a9fe01278..98dccaa9a24 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 22cc4e2436e..fe346f9956a 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -33,6 +33,7 @@ cc_library( srcs = ["numpy_bridge.cc"], hdrs = ["numpy_bridge.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -70,7 +71,7 @@ tf_py_wrap_cc( deps = [ ":local_computation_builder", ":numpy_bridge", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index c44e69e6153..afdea88cb7d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,7 +109,7 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 68648a3a176..71351abd593 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -374,7 +375,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); elements.push_back(std::move(literal)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } else if (PyArray_Check(o)) { PyArrayObject* py_array = reinterpret_cast(o); int rank = PyArray_NDIM(py_array); @@ -383,7 +384,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { dimensions[i] = PyArray_DIM(py_array, i); } int np_type = PyArray_TYPE(py_array); - auto literal = Literal::CreateFromDimensions( + auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); TF_RETURN_IF_ERROR( CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 64f0aae0f97..a67c93a4fb7 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index c289c84cff7..6397f1f4791 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums) { HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); - auto lhs_literal = Literal::CreateR4FromArray4D(lhs); - auto rhs_literal = Literal::CreateR4FromArray4D(rhs); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs); std::array ordered_kernel_strides; std::array ordered_input_dimensions; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 9da9bc60a20..8091bed4996 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { - auto result = Literal::CreateR1(ReferenceUtil::Reduce4DTo1D( + auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); LiteralTestUtil::ExpectR1Equal({0}, *result); @@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, @@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, @@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}}, {{1, 1, 1, 1}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, @@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}}, {{1, 2, 2, 2}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, @@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame); Array3D expected = {{{17, 28, 39, 20}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid); Array3D expected = {{{17, 28, 39}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto actual_literal = Literal::CreateR2FromArray2D(*actual); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index f8414468bd9..90efee50b4f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -97,7 +97,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index fe99f700d23..eacc764f8f6 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -136,7 +136,7 @@ cc_library( ":hlo_dce", ":hlo_pass", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -227,6 +227,7 @@ cc_library( ":hlo", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -244,7 +245,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_evaluator", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -294,6 +295,7 @@ cc_library( ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_tree", @@ -396,6 +398,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":hlo_parser", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -407,7 +410,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_parser", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -424,7 +427,7 @@ tf_cc_test( srcs = ["hlo_sharding_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -453,7 +456,7 @@ tf_cc_test( srcs = ["call_graph_test.cc"], deps = [ ":call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -502,7 +505,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -521,7 +524,7 @@ tf_cc_test( deps = [ ":call_graph", ":flatten_call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -797,7 +800,7 @@ cc_library( hdrs = ["transfer_manager.h"], deps = [ ":shaped_buffer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -960,7 +963,7 @@ tf_cc_test( ":hlo", ":hlo_ordering", ":hlo_scheduling", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1038,7 +1041,7 @@ tf_cc_test( ":hlo_ordering", ":hlo_value", ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1121,7 +1124,7 @@ cc_library( hdrs = ["hlo_query.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", ], ) @@ -1170,6 +1173,7 @@ cc_library( deps = [ ":hlo", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1200,6 +1204,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1219,6 +1224,7 @@ cc_library( ":hlo_creation_utils", ":hlo_pass", ":while_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", ], @@ -1233,7 +1239,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1255,6 +1261,7 @@ cc_library( ":hlo_pass", ":hlo_query", ":pattern_matcher", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1274,7 +1281,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1310,7 +1317,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1345,7 +1352,7 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1361,6 +1368,7 @@ tf_cc_test( ":conditional_simplifier", ":hlo", ":hlo_matchers", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1420,7 +1428,7 @@ tf_cc_test( deps = [ ":defuser", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1448,7 +1456,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":implicit_broadcast_remover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1490,7 +1498,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1505,7 +1513,7 @@ cc_library( hdrs = ["reshape_mover.h"], deps = [ ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -1520,7 +1528,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":reshape_mover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1555,7 +1563,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":inliner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1572,7 +1580,7 @@ cc_library( hdrs = ["computation_placer.h"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1604,7 +1612,7 @@ cc_library( hdrs = ["generic_transfer_manager.h"], deps = [ ":transfer_manager", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1695,7 +1703,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1710,6 +1718,7 @@ tf_cc_binary( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1724,7 +1733,7 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1822,7 +1831,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1859,7 +1868,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_liveness_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1920,7 +1929,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1955,6 +1964,7 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1973,6 +1983,7 @@ tf_cc_test( ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2044,7 +2055,7 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -2169,6 +2180,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dce", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -2189,7 +2201,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_module_dce", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2213,7 +2225,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":layout_assignment", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2272,7 +2284,7 @@ cc_library( ":hlo", ":hlo_domain_map", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -2288,7 +2300,7 @@ tf_cc_test( ":hlo", ":hlo_cse", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2310,7 +2322,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2325,7 +2337,7 @@ tf_cc_test( ":hlo_constant_folding", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2417,7 +2429,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2552,7 +2564,7 @@ cc_library( hdrs = ["hlo_tfgraph_builder.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", @@ -2583,7 +2595,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", @@ -2601,6 +2613,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:test_utils", @@ -2632,7 +2645,7 @@ tf_cc_test( ":hlo_matchers", ":shape_inference", ":transpose_folding", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -2653,7 +2666,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -2668,7 +2681,7 @@ tf_cc_test( ":hlo", ":shape_inference", ":zero_sized_hlo_elimination", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -2828,6 +2841,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":tuple_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:lib", ], ) @@ -2963,6 +2977,7 @@ cc_library( ":hlo", ":hlo_lexer", ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 1ddeb27e404..af7728da549 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -195,7 +196,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -537,8 +538,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = - MakeUnique(constant->literal().GetFirstScalarLiteral()); + std::unique_ptr unique_scalar = MakeUnique( + LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); return ReplaceWithNewInstruction( @@ -1093,7 +1094,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -1519,7 +1520,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - Literal::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1554,7 +1555,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2098,7 +2099,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction::CreateBroadcast( convolution->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::Zero(convolution->shape().element_type()) + LiteralUtil::Zero(convolution->shape().element_type()) .CloneToUnique())), {})); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index b733f6f59eb..92bbcbd740f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -79,7 +79,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloComputation::Builder builder(TestName()); // Create add computation. HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -119,7 +119,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); @@ -140,9 +140,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f))); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); @@ -165,7 +165,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -200,7 +200,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateMap( r2f32, {param0, builder.AddInstruction( @@ -223,7 +223,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -242,7 +242,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -258,7 +258,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14, 3.14, 4}))); + LiteralUtil::CreateR1({3.14, 3.14, 4}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -277,7 +277,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -298,7 +298,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); @@ -493,7 +493,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -559,7 +559,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -580,7 +580,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -860,7 +860,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -884,7 +884,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -912,7 +912,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -934,7 +934,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -956,7 +956,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -1047,7 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, @@ -1074,7 +1074,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); module().AddEntryComputation(builder.Build()); EXPECT_THAT(module().entry_computation()->root_instruction(), @@ -1116,7 +1116,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -1208,7 +1208,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -1238,7 +1238,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -1420,7 +1420,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); @@ -1443,7 +1443,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, @@ -1726,7 +1726,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1757,7 +1757,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -2109,7 +2109,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = builder.AddInstruction( @@ -2156,7 +2156,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -2187,7 +2187,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -2238,7 +2238,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); @@ -2273,7 +2273,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, convert, reduce_init_value, window, @@ -2344,9 +2344,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); @@ -2362,9 +2362,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get()}); + std::unique_ptr value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -2387,8 +2387,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0, 0, 0}))), /*slice_sizes=*/{10, 100, 1000})); auto computation = module().AddEntryComputation(builder.Build()); @@ -2421,8 +2421,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateParameter(2, slice_shape, "to_update")), slice, - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0, 0, 0}))))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2437,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({3, 4}))); HloInstruction* inner_bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, input_array, {1})); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); @@ -2546,7 +2546,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( pad_shape, input, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); HloComputation* add_computation = nullptr; @@ -2565,7 +2565,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { Window window = window_util::MakeWindow( decorate_spatials(param.reduce_window_spatials, 1, 1)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ShapeInference::InferReduceWindowShape( pad->shape(), zero->shape(), window, @@ -2704,7 +2704,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); @@ -2783,7 +2783,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); DotDimensionNumbers dot_dnums; @@ -2830,7 +2830,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = @@ -2879,7 +2879,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -2887,7 +2887,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_col = (spec.lcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); @@ -2898,7 +2898,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -2946,7 +2946,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -2957,7 +2957,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -2965,7 +2965,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_col = (spec.rcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec13fadbc75..aed5832eee3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -97,7 +98,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(-0.5f))))), + LiteralUtil::CreateR0(-0.5f))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, operand, exponent); @@ -113,7 +114,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0 / element_count))))), + LiteralUtil::CreateR0(1.0 / element_count))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, operand, elem_count_recip); @@ -200,11 +201,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, @@ -320,7 +321,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, @@ -447,11 +448,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 feature_count = activation_shape.dimensions(feature_index); const int64 elements_per_feature_int64 = size_in_elements / feature_count; - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); @@ -542,7 +543,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); auto elements_per_feature_literal = - Literal::CreateR0(elements_per_feature_int64); + LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); auto elements_per_feature = add( diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aa36e64b070..90967922373 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ff6d5027efb..b6f3c84c7e6 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 2124b302ccc..23aa83ea882 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { array_b.FillUnique(10.0f); HloInstruction* a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); @@ -150,10 +150,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)), + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)), + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 6958ee722a8..eb19babf770 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" @@ -125,7 +125,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -142,7 +142,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -167,9 +167,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -290,7 +290,7 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -304,9 +304,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); @@ -327,7 +327,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32vec100_, "param0")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( @@ -352,7 +352,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -660,7 +660,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -708,9 +708,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -773,11 +773,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); auto const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.4f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( r0f32_, pred, const1, true_computation, const2, false_computation)); module->AddEntryComputation(builder.Build()); @@ -1200,8 +1200,9 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1584,7 +1585,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { auto b = HloComputation::Builder(TestName() + ".cond"); b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); condition = module->AddEmbeddedComputation(b.Build()); } HloComputation* body; @@ -1647,9 +1648,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1708,7 +1709,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1851,7 +1852,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_cond = [&]() { auto builder = HloComputation::Builder("cond"); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1863,7 +1864,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_body = [&]() { auto builder = HloComputation::Builder("body"); auto const9 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction( @@ -1891,7 +1892,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); auto cond2 = module->AddEmbeddedComputation(build_cond()); @@ -1953,7 +1954,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -1997,16 +1998,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -2058,9 +2059,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto input0 = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape_, "input0")); @@ -2142,7 +2143,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 7833ebe73ba..c5e4b72fbcd 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -439,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); - auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); + auto inner_tuple0 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}); + auto inner_tuple1 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -491,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -503,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -555,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -627,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -638,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -757,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -766,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 1ea7d538cd5..cc80b748431 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation::Builder builder(TestName()); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.6f))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, const1, true_computation, const2, diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 924348c870b..dcec2babcb8 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // the "one" value. HloComputation::Builder inner(TestName() + ".inner"); HloInstruction* zero = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(24.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(24.0f))); HloInstruction* one = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); auto module = CreateNewModule(); HloComputation* inner_computation = @@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // little trickier. HloComputation::Builder just_false(TestName() + ".false"); just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { HloComputation::Builder outer(TestName() + ".outer"); HloInstruction* init_value = outer.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); outer.AddInstruction( HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); @@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({true}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({true}))); auto* false_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -147,7 +147,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({})); outfeeder.AddInstruction( HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 7c1bacff92b..d26486fcfe0 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index e9ec796121f..b7be3ba605a 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 68f6ffc6b70..834878426f5 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { true_computation_builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "param")); auto one = true_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); true_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); @@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "param")); auto forty_two = false_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); false_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); @@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } auto false_instrn = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "false_param")); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction(HloInstruction::CreateConditional( ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, @@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { HloComputation* computation = MakeConditional(&module()); auto* true_op = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); @@ -123,7 +123,7 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { true_computation->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( true_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 105d117cacc..cd735256b83 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -132,7 +132,7 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); + LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); Layout reversed_layout = LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); @@ -167,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -197,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -209,7 +209,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); @@ -255,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -370,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -380,7 +381,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); HloInstruction* gte = @@ -413,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -442,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). @@ -480,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -549,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -564,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -598,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ -608,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -633,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -659,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -677,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -689,7 +694,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); @@ -701,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -714,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -750,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto condition = module_->AddEmbeddedComputation( BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( @@ -1252,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition1, body1, loop_init)); @@ -1310,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1318,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1375,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1383,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1435,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1443,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); builder.AddInstruction( @@ -1520,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1575,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -1644,7 +1649,7 @@ std::unique_ptr MakeTrivialCondition(const Shape& shape) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "loop_state")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNot, constant)); return builder.Build(); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 3479240610a..a45de4c479b 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -37,6 +37,7 @@ cc_library( srcs = ["cpu_transfer_manager.cc"], hdrs = ["cpu_transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -72,7 +73,7 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -355,7 +356,7 @@ tf_cc_binary( srcs = ["sample_harness.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -717,7 +718,7 @@ tf_cc_test( deps = [ ":cpu_layout_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -809,7 +810,7 @@ tf_cc_test( ":cpu_executable", ":parallel_task_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -892,7 +893,7 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 375b017b092..547d4c696da 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 55962ba70d2..de6178fd528 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -38,7 +38,7 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index a05a2694178..4db7fa446ea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto constant = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, sub_param, constant)); sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 750310c6332..991b14f17db 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "arg0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); return module->AddEmbeddedComputation(builder.Build()); @@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) { auto pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(S32, {5}), idx_choice, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), padding_config)); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 429fc7b7860..3681d12d8da 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b877b295814..156166bf2b1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -180,7 +181,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*Literal::CreateFromDimensions( + *literal = std::move(*LiteralUtil::CreateFromDimensions( literal_shape.element_type(), dimensions)); TF_ASSIGN_OR_RETURN(Shape received_shape, TransferArrayBufferFromOutfeed( @@ -211,7 +212,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::bit_cast( tuple_element_shape.dimensions().data()), tuple_element_shape.dimensions().size()); - auto empty = Literal::CreateFromDimensions( + auto empty = LiteralUtil::CreateFromDimensions( tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); buffer_data.push_back({empty->untyped_data(), size}); @@ -232,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); } - *literal = std::move(*Literal::MakeTupleOwned(std::move(elements))); + *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements))); TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 7e792a82b8b..d9e8dcaed98 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -38,12 +38,13 @@ int main(int argc, char** argv) { // Transfer parameters. std::unique_ptr param0_literal = - xla::Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = xla::Literal::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 66ae5ef0f66..b4c33e2f6ca 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -40,7 +40,7 @@ tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -82,7 +82,7 @@ tf_cc_test( name = "cpu_noalias_test", srcs = ["cpu_noalias_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -128,7 +128,7 @@ tf_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 1d4bf483aed..00a7aa2ad2f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(backing_array))); + LiteralUtil::CreateR2FromArray2D(backing_array))); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 783b2820e92..d98856fdbf4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase { TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); Shape vshape = input_literal1->shape(); auto input1 = builder.AddInstruction( @@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -99,7 +99,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( vshape, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -134,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -166,7 +166,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), /*init_value=*/ builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{1}, add_f32)); auto exp = builder.AddInstruction( @@ -176,7 +176,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( cshape, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -231,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // operand vectors. Test for this problem by counting the number of nodes in // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Shape vshape = input_literal->shape(); auto constant = builder.AddInstruction( @@ -292,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { // computation. The duplication is caused by the other use of exp2 in the // tuple. auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); Shape shape = constant->shape(); auto exp1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index ea7e479d66f..0d45918d099 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*Literal::CreateR0(true)); + TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0minor)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0major)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*Literal::CreateR4( + TestInfeedRoundTrip(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { TestInfeedRoundTrip( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(false).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(false).get()})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*Literal::MakeTuple({})); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { }); // Send 5 Infeed data of shape F32[3]. - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); @@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), + LiteralUtil::CreateR0(false).get()}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), + LiteralUtil::CreateR0(false).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 3b6b0ed7406..ccb61740f6b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index 32b5c5d35fa..e727ba49cb6 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/defuser.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" @@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); @@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 52aa53dcee5..51f16bdc947 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ecd97a87968..0686ca74afc 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 8980d430335..addb016b048 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -57,8 +57,8 @@ ENTRY main { } )"; - std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); RunTest(hlo_text, {lhs.get(), rhs.get()}); } } // namespace diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index d3854b40de3..8f6608241ed 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, false_constant)); @@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { // computation in the true and false branch. HloComputation::Builder builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, constant1, sub_computation, constant2, sub_computation)); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 7cd2c9c136a..e3a42d0d06b 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -113,7 +114,7 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( const Shape& index_shape = index_vector->shape(); HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); // We extract out individual components from the smaller index and concatenate // them (interspersing zeros as needed) into the larger index. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 85e28a0dfe3..7490728b448 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d90b0fb57d7..34a5b01ee2c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -150,7 +150,7 @@ cc_library( ":parallel_loop_emitter", ":partition_assignment", ":while_transformer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -199,7 +199,7 @@ cc_library( srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -351,6 +351,7 @@ cc_library( ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -382,7 +383,7 @@ cc_library( hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -517,6 +518,7 @@ cc_library( hdrs = ["pad_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -533,7 +535,7 @@ cc_library( hdrs = ["gpu_transfer_manager.h"], deps = [ ":gpu_compiler", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -624,6 +626,7 @@ cc_library( hdrs = ["cudnn_batchnorm_rewriter.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -716,7 +719,7 @@ cc_library( srcs = ["while_transformer.cc"], hdrs = ["while_transformer.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index c77e3c81c9d..60289506524 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -128,8 +131,8 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, computation_->AddInstruction(HloInstruction::CreateBroadcast( inverse_stddev->shape(), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-2))), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-2))), {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -169,11 +172,12 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but // the batchnorm HLO takes plain variance as input. Fix it up. @@ -189,7 +193,7 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { computation_->AddInstruction(HloInstruction::CreateBroadcast( var_plus_epsilon->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(-.5))), + LiteralUtil::CreateR0(-.5))), {})))); std::vector operands(batch_norm->operands().begin(), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 3dc98c4c93e..5a63e65208a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -80,8 +81,7 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, const ConvolutionDimensionNumbers& dnums, se::StreamExecutor* stream_exec) { // Skip this check for cudnn7 and newer. - auto version = - stream_exec->AsDnn()->GetVersion(); + auto version = stream_exec->AsDnn()->GetVersion(); if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } @@ -338,8 +338,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( computation->AddInstruction(HloInstruction::CreateTuple( {computation->AddInstruction(HloInstruction::CreateGetTupleElement( new_call_shape.tuple_shapes(0), new_call, 0)), - computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({})))})); + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({})))})); TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); return true; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index f9dccd287d9..905b5ee8767 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 27d2c3e491b..e594cec2f8d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index e48165c1426..95f78ae2932 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -132,10 +132,10 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { HloInstruction::CreateParameter(4, aux_shape, "variance")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( shape, @@ -201,10 +201,10 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { HloInstruction::CreateParameter(2, offset_scale_shape, "offset")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, @@ -278,10 +278,10 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { HloInstruction::CreateParameter(4, shape, "var")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 5343497c03c..83d5083b95f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "llvm/IR/DataLayout.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1963d9eef72..98ba162cd97 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 80208e1c985..1d0b6597ebb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -28,7 +28,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index c8f0d4185c6..b22040eee16 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -68,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -234,9 +235,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(input->shape().element_type())))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 7749201cbce..c5321df6c46 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 2f290f61bd5..dbc8442ed27 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -42,7 +42,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(limit))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = @@ -65,8 +65,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(increment))); + auto inc = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(data_tuple_index). @@ -89,10 +89,12 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index acf66114869..aa89567ee86 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -47,7 +48,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return module->AddEmbeddedComputation(builder.Build()); @@ -122,7 +123,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 3849b565e31..b41dc66fe9f 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -239,7 +239,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); @@ -674,7 +674,7 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue* DummyBufferValue() { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index a59bf1750c0..403d4df6b50 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -116,9 +116,9 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -228,9 +228,9 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -267,9 +267,9 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -346,15 +346,15 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -439,15 +439,15 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -498,7 +498,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return cond_builder.Build(); }; // Build separate condition computations so the call graph is flat. The @@ -543,9 +543,9 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -608,17 +608,17 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( @@ -657,15 +657,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -753,16 +753,16 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -805,7 +805,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -824,7 +824,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { // interference. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); @@ -843,13 +843,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // the other use of the init. auto builder = HloComputation::Builder(TestName()); auto init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, init->shape(), "param")); auto cond_root = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index a8f3f0e9c2d..af4628cf58e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { // copy are specified. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); { @@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto computation = builder.Build(); @@ -390,7 +390,7 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); auto module = CreateNewModule(); @@ -407,7 +407,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -433,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -456,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -502,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) { // There is a control dependency from 'add' to 'exp'. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto negate = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 436d103f230..7229031c0c7 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 5d05ccfc0b2..64a42c1efc0 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -62,7 +62,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -82,8 +82,8 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -120,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = Literal::CreateFromDimensions(F32, dimensions); + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9fc4c48226f..9fd0363f578 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -338,13 +338,13 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // tuple = Tuple({sub, sub, mul, C1}) HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2)); @@ -391,9 +391,9 @@ TEST_F(FusionCostAnalysis, NoLayout) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); + LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(shape_without_layout, c2, {1})); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 0fb65c845a6..90d2be118d9 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -261,9 +262,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(operand->shape().element_type())))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -272,7 +273,7 @@ StatusOr BroadcastZeros( ArraySlice broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 7e7c4f95fed..60d3e71757d 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -60,8 +60,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -82,10 +82,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { std::unique_ptr result_literal, evaluator.Evaluate>( *module, - {Literal::CreateR3( + {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(*result_literal, - *Literal::CreateR2( + *LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,11 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { entry_computation->set_root_instruction(with_1_degenerate_dim_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -123,10 +124,11 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -145,8 +147,8 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(9)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); + *module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -166,9 +168,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({1, 2, 3, 4, 5, 6})})); + *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(*result_literal, - *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -188,8 +190,8 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -209,8 +211,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); + *module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -230,9 +232,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0.0f)})); + *module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(*result_literal, - *Literal::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index a0ee8896230..3f1deec2df9 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 16db374566c..c98a79fc71b 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR0(84.0); + auto expected = LiteralUtil::CreateR0(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { auto builder = HloComputation::Builder(TestName()); std::vector constants; constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); for (int64 i = 0; i < constants.size(); ++i) { @@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -253,7 +253,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -284,7 +284,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -362,7 +362,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -400,9 +400,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -442,9 +442,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -459,7 +459,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( @@ -521,9 +521,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { // in this case) are not collapsed. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f176473366a..da3a02f11cf 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -101,9 +101,9 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -198,9 +198,9 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -259,9 +259,9 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -308,9 +308,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -362,9 +362,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -426,9 +426,9 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -493,15 +493,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -594,15 +594,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -653,7 +653,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -691,9 +691,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -780,15 +780,15 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -840,11 +840,11 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); @@ -863,15 +863,15 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -939,17 +939,17 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { // Test kTupleSelect of a nested tuple. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -1025,18 +1025,18 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -1088,7 +1088,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -1309,13 +1309,13 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); auto constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto exp = body_builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, exp, body_param)); auto dead_constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, dead_constant)); HloComputation* body = module_->AddEmbeddedComputation( @@ -1325,7 +1325,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -1576,11 +1576,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( scalar_shape_, pred, constant1, true_computation, constant2, false_computation)); @@ -1667,11 +1667,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1797,15 +1797,15 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { // Build entry computation. auto builder = HloComputation::Builder(TestName()); auto pred1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto pred2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.2f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.3f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({pred2, constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1943,9 +1943,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -2048,7 +2048,7 @@ TEST_F(CanShareOperandBufferWithUserTest, Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2076,7 +2076,7 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); auto index = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); auto ds = builder.AddInstruction( HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); @@ -2144,9 +2144,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -2184,9 +2184,9 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape_bf16, convert1, update, starts)); @@ -2237,9 +2237,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -2248,7 +2248,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2270,7 +2270,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2278,7 +2278,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -2298,13 +2298,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kMultiply, operand, operand)); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two)); @@ -2370,7 +2370,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index f5524dc6fef..4fa13c975ad 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -53,9 +53,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -74,7 +74,7 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { // Verify that side-effect instructions (Send in this test) are not removed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); builder.AddInstruction( HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); @@ -127,9 +127,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( @@ -224,7 +224,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { auto param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); } @@ -346,12 +346,12 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); module->AddEntryComputation(builder.Build()); EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); @@ -387,7 +387,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation that also uses @@ -397,7 +397,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 4ed1508d706..c804f4364f6 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 47da46bfad6..f68b4ca353a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -135,7 +136,6 @@ StatusOr> Compare( } // namespace - HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = MakeUnique>(this); @@ -382,7 +382,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = Literal::CreateFromDimensions( + auto result_literal = LiteralUtil::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); @@ -533,7 +533,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { operand_literals.push_back(&GetEvaluatedLiteralFor(operand)); } - evaluated_[tuple] = Literal::MakeTuple(operand_literals); + evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals); return Status::OK(); } @@ -903,7 +903,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { } Status HloEvaluator::HandleAfterAll(HloInstruction* token) { - evaluated_[token] = Literal::CreateToken(); + evaluated_[token] = LiteralUtil::CreateToken(); return Status::OK(); } @@ -1119,7 +1119,7 @@ std::unique_ptr EvaluateSortInternal(HloInstruction* sort, auto result_values_literal = MakeUnique(sort->operand(1)->shape()); result_values_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_values)); - auto result_tuple = Literal::MakeTuple( + auto result_tuple = LiteralUtil::MakeTuple( {result_keys_literal.get(), result_values_literal.get()}); VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); return result_tuple; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 42770d848a8..5f575b24a1f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" @@ -112,9 +112,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_P(HloEvaluatorTest, DoesClamp) { - auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); @@ -127,15 +127,15 @@ TEST_P(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { - auto low = Literal::CreateR0(0.f); - auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); - auto high = Literal::CreateR0(1.f); + auto low = LiteralUtil::CreateR0(0.f); + auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high = LiteralUtil::CreateR0(1.f); Shape shape = value->shape(); HloComputation::Builder b(TestName()); @@ -148,7 +148,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -156,9 +156,9 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_P(HloEvaluatorTest, DoesSelect) { - auto pred = Literal::CreateR2({{true, false}, {false, true}}); - auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); + auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); HloComputation::Builder b(TestName()); @@ -173,7 +173,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = Evaluate({}); - auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); + auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -181,46 +181,46 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. TEST_P(HloEvaluatorTest, DoesAdd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. TEST_P(HloEvaluatorTest, DoesAnd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. TEST_P(HloEvaluatorTest, DoesOr) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. TEST_P(HloEvaluatorTest, DoesXor) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-104, 0}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. TEST_P(HloEvaluatorTest, DoesMultiply) { - auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 0}, {-400, 16}}); TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), std::move(rhs)); @@ -228,17 +228,17 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. TEST_P(HloEvaluatorTest, DoesDivideInt64) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } TEST_P(HloEvaluatorTest, DoesDivideDouble) { - auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = - Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } @@ -246,54 +246,54 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_P(HloEvaluatorTest, DoesAbsR2) { - auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); + auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR0) { - auto operand = Literal::CreateR0(-1.0f); - auto expected = Literal::CreateR0(1.0f); + auto operand = LiteralUtil::CreateR0(-1.0f); + auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - auto operand = Literal::CreateR1({}); - auto expected = Literal::CreateR1({}); + auto operand = LiteralUtil::CreateR1({}); + auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesNegateR2) { - auto operand = Literal::CreateR2( + auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); - auto expected = - Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + auto expected = LiteralUtil::CreateR2( + {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesCosR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = - Literal::CreateR2({{0, std::numeric_limits::min()}, - {-1, std::numeric_limits::max()}}); + LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); auto expected = - Literal::CreateR2({{-1, std::numeric_limits::max()}, - {0, std::numeric_limits::min()}}); + LiteralUtil::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -314,7 +314,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { std::unique_ptr result = Evaluate(args); - auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); + auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -324,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -349,8 +349,8 @@ TEST_P(HloEvaluatorTest, DoesReshape) { // Verifies Broadcast operation is correctly evaluated. TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); - auto output_literal = Literal::CreateR3( + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -365,8 +365,8 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR0(111); - auto output_literal = Literal::CreateR2( + auto input_literal = LiteralUtil::CreateR0(111); + auto output_literal = LiteralUtil::CreateR2( {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}}); HloInstruction* literal_instruction = b.AddInstruction( @@ -386,9 +386,9 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1, -2}, {100, 200}}))); + LiteralUtil::CreateR2({{-1, -2}, {100, 200}}))); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2, -3}, {-100, -200}}))); + LiteralUtil::CreateR2({{-2, -3}, {-100, -200}}))); std::vector operands = {operand1, operand2}; @@ -399,8 +399,8 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { std::unique_ptr result = Evaluate(); - auto expected = - Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); + auto expected = LiteralUtil::CreateR2( + {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -408,9 +408,9 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({100, 200}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({100, 200}))); HloInstruction* operand2 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); std::vector operands = {operand1, operand2}; @@ -421,16 +421,16 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({100, 200}); + auto expected = LiteralUtil::CreateR1({100, 200}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -447,9 +447,9 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2WithLayout( + auto input_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); - auto expected = Literal::CreateR2WithLayout( + auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -478,13 +478,13 @@ PaddingConfig CreatePaddingConfig( } TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { - auto operand = Literal::CreateR2({{}, {}}); + auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); constexpr int32 kPadValue = 10; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); auto padding_value_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -496,7 +496,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -506,11 +506,11 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); constexpr float kPadValue = 1.5; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); HloInstruction* pad_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -532,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { (*expected_array)(7, 0, 0, 0) = 5.0f; (*expected_array)(7, 2, 0, 0) = 6.0f; - auto expected = Literal::CreateR4FromArray4D(*expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -549,12 +549,12 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); auto r2_padding_on_dim0_dim1 = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}}); @@ -574,7 +574,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 2) = 2.718f; (*expected_array)(0, 3) = 2.718f; (*expected_array)(0, 4) = 2.718f; - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } @@ -590,12 +590,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -613,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(0, 9); - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -630,13 +630,13 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // } auto lhs_array = MakeUnique>(4, 1); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); // rhs: // f32[2] { 1, 2 }, - auto rhs_literal = Literal::CreateR2({{1, 2}}); + auto rhs_literal = LiteralUtil::CreateR2({{1, 2}}); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -658,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { {4.f, 8.f}, }); // clang-format on - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -669,7 +669,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // lhs: // f32[3] // { 1, 2, 3 }, - auto lhs_literal = Literal::CreateR1({1, 2, 3}); + auto lhs_literal = LiteralUtil::CreateR1({1, 2, 3}); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -681,7 +681,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -695,7 +695,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({22.f, 28.f}); + auto expected = LiteralUtil::CreateR1({22.f, 28.f}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -712,7 +712,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto lhs_array = MakeUnique>(4, 3); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -724,7 +724,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -744,7 +744,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { {94.f, 124.f}, {130.f, 172.f}, }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -753,12 +753,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; - auto lhs_literal = Literal::CreateR3FromArray3D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR3FromArray3D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); Array3D rhs_array = {{{3.f, 4.f}}}; - auto rhs_literal = Literal::CreateR3FromArray3D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR3FromArray3D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -792,7 +792,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { std::unique_ptr result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -809,7 +809,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -820,7 +820,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -854,7 +854,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {149, 160, 171, 80}, })); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -884,11 +884,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( @@ -933,7 +933,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -964,11 +964,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1010,7 +1010,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -1028,7 +1028,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1039,7 +1039,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1074,7 +1074,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {91, 112, 98, 120, 105, 128, 112}, {65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1091,7 +1091,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1102,7 +1102,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1138,7 +1138,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {104, 91, 112, 98, 120, 105, 128, 112}, {78, 65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1156,7 +1156,7 @@ TEST_P(HloEvaluatorTest, {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1167,7 +1167,7 @@ TEST_P(HloEvaluatorTest, {8, 9, 10}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1210,7 +1210,7 @@ TEST_P(HloEvaluatorTest, {0, 0, 0}, {91, 98, 105}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1225,9 +1225,9 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); HloInstruction* init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1262,9 +1262,9 @@ void BM_ReducePrecisely(int num_iters) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1299,13 +1299,13 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1326,7 +1326,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({6, 18}); + auto expected = LiteralUtil::CreateR1({6, 18}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1341,13 +1341,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder max_computation("max"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1378,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{6, 7}}); + auto expected = LiteralUtil::CreateR2({{6, 7}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1392,13 +1392,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1435,7 +1435,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); + auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1445,13 +1445,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1498,7 +1498,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } @@ -1513,7 +1513,8 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // } auto operand_array = MakeUnique>(3, 5); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1527,7 +1528,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); @@ -1545,13 +1546,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1560,7 +1562,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1580,13 +1582,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1595,7 +1598,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1613,16 +1616,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto update = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); + LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( @@ -1631,7 +1635,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); @@ -1649,12 +1653,13 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal2 = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal2 = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand2 = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal2))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1666,7 +1671,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); @@ -1686,9 +1691,9 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(*operand_array))); + LiteralUtil::CreateR2FromArray2D(*operand_array))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple1 = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1706,8 +1711,8 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { std::unique_ptr result = Evaluate(); auto result_inner_literal = - Literal::CreateR2FromArray2D(*operand_array); - auto expected = Literal::MakeTuple({ + LiteralUtil::CreateR2FromArray2D(*operand_array); + auto expected = LiteralUtil::MakeTuple({ result_inner_literal.get(), result_inner_literal.get(), }); @@ -1735,7 +1740,7 @@ TEST_P(HloEvaluatorTest, Reverse) { {{23.0f}, {24.0f}}}, }); // clang-format on - auto operand_literal = Literal::CreateR4FromArray4D(input); + auto operand_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1746,7 +1751,7 @@ TEST_P(HloEvaluatorTest, Reverse) { std::unique_ptr result = Evaluate(); // clang-format off - auto expected = Literal::CreateR4FromArray4D({ + auto expected = LiteralUtil::CreateR4FromArray4D({ {{{23.0f}, {24.0f}}, {{21.0f}, {22.0f}}, {{19.0f}, {20.0f}}}, @@ -1782,11 +1787,11 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, - {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, + {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1799,18 +1804,18 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0")); HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kMultiply, param0, param0)); - HloInstruction* constant = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); HloInstruction* add = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square)); // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1830,11 +1835,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1854,10 +1860,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1878,11 +1885,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3( + *LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1904,13 +1911,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1932,13 +1939,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1959,10 +1966,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1983,11 +1991,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2007,10 +2015,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2031,11 +2040,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2043,14 +2052,14 @@ ENTRY main { // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs - auto lhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); - auto rhs = Literal::CreateR2( + auto rhs = LiteralUtil::CreateR2( {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = - Literal::CreateR2({{false, true, true}, {false, true, true}}); + LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), std::move(rhs)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index cdbac74ba4c..2ae5f8bf36d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -1316,7 +1317,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand); auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); arg_literals.push_back(std::move(curr_val_literal)); } @@ -1504,8 +1505,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); + auto result_val_literal = + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator @@ -1583,10 +1585,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid // dynamic memory allocations. - auto curr_val_literal = Literal::CreateR0(ReturnT()); - auto selected_val_literal = Literal::CreateR0(ReturnT()); - auto source_literal_scatter = Literal::CreateR0(ReturnT()); - auto scattered_literal = Literal::CreateR0(ReturnT()); + auto curr_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto selected_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto source_literal_scatter = LiteralUtil::CreateR0(ReturnT()); + auto scattered_literal = LiteralUtil::CreateR0(ReturnT()); do { // For each element in `source`, we place a window in `operand`. For each // window placement, we iterate inside the window twice: @@ -1707,9 +1709,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. const auto curr_val_literal = - Literal::CreateR0(curr_val); + LiteralUtil::CreateR0(curr_val); const auto result_val_literal = - Literal::CreateR0(result_val); + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator .Evaluate( @@ -1754,7 +1756,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = Literal::CreateFromDimensions( + auto result = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); TF_RETURN_IF_ERROR(result->Populate(func)); parent_->evaluated_[slice] = std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 7a1372f9298..57cf34d7dee 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -27,7 +27,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 68f41a1cbb4..1d7a062c556 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -120,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) { TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6ea302f8b41..b396042f520 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 34e7dcb43d4..17cc6d35cc2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -33,7 +33,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d8ca99dfd12..e37556ac8d0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -249,7 +249,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto addleft = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0)); auto addright = builder.AddInstruction( @@ -294,7 +294,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto neg1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0)); auto addleft = builder.AddInstruction( @@ -334,7 +334,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = embedded_builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "x")); auto value = embedded_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); @@ -383,9 +383,9 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); @@ -626,7 +626,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto module = CreateNewModule(); @@ -642,9 +642,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single binary operation. auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto module = CreateNewModule(); @@ -661,7 +661,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -682,7 +682,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -710,7 +710,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); @@ -732,7 +732,7 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { HloComputation::Builder builder(TestName()); auto* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); @@ -763,7 +763,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto map_1_x = builder.AddInstruction( HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); auto map_2_x = builder.AddInstruction( @@ -798,11 +798,11 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Notable complexities are repeated operands in the same instruction, // different shapes, use of value in different expressions. auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); @@ -873,11 +873,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = - HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); + LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. @@ -1234,9 +1234,9 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { // Build a nested fusion computation. Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); DotDimensionNumbers dot_dnums; @@ -1245,7 +1245,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1342,7 +1342,7 @@ TEST_F(HloInstructionTest, Stringification) { "condition=%TransposeDot, body=%TransposeDot"); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); @@ -1550,7 +1550,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction::CreateWhile(sout, computation, computation, x)); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7052e236cda..ed934c689a5 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -766,7 +767,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstruction* operand) : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), - literal_(Literal::CreateR1U8(tag)) { + literal_(LiteralUtil::CreateR1U8(tag)) { AppendOperand(operand); operand->set_tracing(this); } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 0275294a1a8..01b625c29ca 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 9a3010cf1ff..7de59acc1ef 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -75,8 +76,10 @@ TEST(HloMatchersTest, Test) { } TEST(HloMatchersTest, CustomCallMatcher) { - auto c1 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto c1 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + auto c2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); auto call = HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target"); diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 7f28a804bfe..236f4500860 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); return builder.Build(); } @@ -122,7 +122,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { { auto b = HloComputation::Builder("Entry"); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); b.AddInstruction( HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, /*operands=*/{input}, fused_computation)); @@ -173,7 +173,7 @@ TEST_F(HloModuleTest, LargeConstantToString) { auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(values))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(values))); module->AddEntryComputation(builder.Build()); EXPECT_EQ( diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index cfe5dace05a..126d3a2d9c7 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -57,7 +57,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -145,7 +145,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); module->AddEntryComputation(builder.Build()); @@ -208,7 +208,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); auto add = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f192debc9c7..54fc34b862f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -1609,7 +1610,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, } } } - *literal = Literal::MakeTupleOwned(std::move(elements)); + *literal = LiteralUtil::MakeTupleOwned(std::move(elements)); return ParseToken(TokKind::kRparen, StrCat("expects ')' at the end of the tuple with ", ShapeUtil::TupleElementCount(shape), "elements")); @@ -1637,8 +1638,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // Create a literal with the given shape in default layout. - *literal = Literal::CreateFromDimensions(shape.element_type(), - AsInt64Slice(shape.dimensions())); + *literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); tensorflow::int64 nest_level = 0; tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2418c19f3de..2a07b6fcbc2 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 657a9ee83d2..585c95972b0 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -39,15 +39,15 @@ TEST_F(HloReachabilityTest, Reachability) { */ auto builder = HloComputation::Builder(TestName()); auto a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto c = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto d = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.Build(); HloReachabilityMap reachability({a, b, c, d, e}); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 7a46da6efe0..cd131147e61 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -132,7 +132,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return builder.Build(); } @@ -226,7 +226,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -263,7 +263,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -296,7 +296,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 73f22f81f4e..cf9ceed5b2f 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -168,8 +168,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto cond_builder = HloComputation::Builder("WhileCond"); HloInstruction* cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -179,16 +180,18 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto body_builder = HloComputation::Builder("WhileBody"); HloInstruction* body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); body_builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, body_param, one_vector)); auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); // transpose(matrix) + bcast(while) auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); // Creates 16 bytes, ignoring subcomputations HloInstruction* while_loop = builder.AddInstruction(HloInstruction::CreateWhile( @@ -199,7 +202,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); // Creates 32 bytes HloInstruction* transpose = builder.AddInstruction( @@ -257,7 +260,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { // Wrap lit in abs because constants are considered free by // IgnoreInstruction, and it skews the accounting. auto lit = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); auto abs_const = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); @@ -300,11 +303,11 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 2, 3, 4, 5}))); + LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0, 2, 4, 6, 8}))); + LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); @@ -354,8 +357,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto cond_builder = HloComputation::Builder("WhileCond"); HloInstruction* cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -365,15 +369,17 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto body_builder = HloComputation::Builder("WhileBody"); HloInstruction* body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); body_builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, body_param, one_vector)); auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); // Creates 16 bytes, ignoring subcomputations builder.AddInstruction(HloInstruction::CreateWhile( r1f32, cond_computation, body_computation, while_init)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 34324d2058e..3d14f9c89e0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,7 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 54b7402b866..7baa927d0e2 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 7b601f9a957..45c684d6675 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -112,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 3dc733940fc..48f676db85a 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/attr_value.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index be156d765dc..1e2b31a1f2b 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -90,7 +90,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index 8c7b38dd1bf..f85d31d5225 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index d2af261008f..32937b33b37 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({4, 3, 3, 4}); + auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -123,10 +123,10 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapSubFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -142,7 +142,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({3, 1, -1, -3}); + auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 524d3234eb4..8652599dc6d 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -74,7 +74,7 @@ cc_library( hdrs = ["executable.h"], deps = [ ":executor", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 9816acf6507..8d40c08d555 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a673901c756..ebd7f696e6a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -141,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = Literal::CreateR2WithLayout( + auto constant_literal1 = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); - auto constant_literal2 = Literal::CreateR2WithLayout( + auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); @@ -192,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // match their source). auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -229,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -240,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -274,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -584,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { auto builder = HloComputation::Builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -803,7 +803,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index f1e7fc29532..ce36afc1e64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -42,7 +42,7 @@ cc_library( srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 97bacc34b59..6c55361b44b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 4a10ec466da..9c51861eaca 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,7 +27,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 49ec38eb62c..ca86c5d13e9 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 13e2d3258e3..ad3b662c20a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -175,8 +175,9 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -255,12 +256,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -309,7 +310,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -348,7 +349,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({9, 8, 7}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index e384359642a..249bdcc1f5d 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index cccb8f2fbb0..7051a4cf517 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -160,11 +160,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 226d0af5d27..d52091487f6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -317,7 +317,7 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { // Send forwards its operand to the output tuple at {0}. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto send = builder.AddInstruction( HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); @@ -365,16 +365,16 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); @@ -403,7 +403,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1)); auto copy = builder.AddInstruction( @@ -443,16 +443,16 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); @@ -474,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -488,7 +488,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); @@ -521,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -557,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -579,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -620,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); @@ -868,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -962,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -1016,9 +1017,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1027,7 +1028,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1049,7 +1050,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1057,7 +1058,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -1122,7 +1123,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index d3635eae81e..39b693872da 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 23519e445ea..a652aafc50d 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -53,7 +53,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return module->AddEmbeddedComputation(builder.Build()); } @@ -125,7 +125,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction(HloInstruction::CreateUnary( scalar_s32, HloOpcode::kNegate, mul_result)); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); HloInstruction* sub_result = builder.AddInstruction(HloInstruction::CreateBinary( scalar_s32, HloOpcode::kSubtract, negate_result, constant)); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 3c830492166..e8e9ce200bd 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest, auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -178,7 +178,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 473eab2ea84..1ef17b9d7d2 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" @@ -38,7 +39,7 @@ static StatusOr WidenWhileCondition( // the root instruction later. We later change the root instruction to // something more appropriate. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); }(); @@ -154,7 +155,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, {&loop_state_shape}, scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, @@ -175,7 +176,7 @@ static StatusOr> MakeCountedLoopBodyComputation( CreateComputationWithSignature( {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); @@ -203,7 +204,7 @@ static StatusOr MakeInitTupleFromInitValues( std::vector init_values_with_indvar; init_values_with_indvar.reserve(init_values.size() + 1); HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 44b0ec5cd4c..0a3a757b277 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index c6bd013a1aa..4f3e95ada53 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 02f6fc3a271..456a33ad85b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -65,6 +65,7 @@ cc_library( srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -88,6 +89,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", @@ -179,6 +181,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -302,7 +305,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -345,7 +348,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -406,7 +409,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -435,7 +438,7 @@ xla_test( tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -531,6 +534,7 @@ xla_test( srcs = ["scalar_computations_test.cc"], shard_count = 32, deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -573,7 +577,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -599,7 +603,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -645,7 +649,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -764,6 +768,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", @@ -780,7 +785,7 @@ xla_test( CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -827,7 +832,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", @@ -874,7 +879,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -907,7 +912,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -940,7 +945,7 @@ xla_test( ], deps = [ ":test_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1031,6 +1036,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1079,6 +1085,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1149,7 +1156,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1176,7 +1183,7 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1228,6 +1235,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", @@ -1246,6 +1254,7 @@ xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -1291,6 +1300,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1368,7 +1378,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1391,7 +1401,7 @@ xla_test( name = "prng_test", srcs = ["prng_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1416,6 +1426,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1530,7 +1541,7 @@ xla_test( name = "cross_replica_sum_test", srcs = ["cross_replica_sum_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1574,7 +1585,7 @@ xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -1614,7 +1625,7 @@ xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1689,7 +1700,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1714,7 +1725,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1731,6 +1742,7 @@ tf_cc_test( srcs = ["llvm_compiler_test.cc"], tags = ["requires-gpu-sm35"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -1751,7 +1763,7 @@ xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:packed_literal_reader", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1774,7 +1786,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1802,7 +1814,7 @@ xla_test( srcs = ["multioutput_fusion_test.cc"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1842,7 +1854,7 @@ xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1865,7 +1877,7 @@ xla_test( shard_count = 30, tags = ["optonly"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1911,7 +1923,7 @@ xla_test( srcs = ["round_trip_transfer_test.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -1932,7 +1944,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1980,7 +1992,7 @@ xla_test( ":literal_test_util", ":local_client_test_base", ":xla_internal_test_main", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 3bdf98544af..3ae96fa1bcb 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -225,7 +225,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -239,7 +239,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); @@ -265,7 +265,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -278,7 +278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); @@ -303,13 +303,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = Literal::CreateR1({a_values}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); - std::unique_ptr b_literal = Literal::CreateR1({b_values}); + std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); @@ -1426,7 +1426,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = Literal::CreateR1(values); + std::unique_ptr param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -1454,10 +1454,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1479,10 +1479,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1504,10 +1504,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1529,10 +1529,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1555,15 +1555,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1587,15 +1587,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); @@ -1620,15 +1620,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); @@ -1654,19 +1654,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = Literal::CreateR1(values3); + std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = client_->TransferToServer(*literal3).ConsumeValueOrDie(); @@ -2101,12 +2101,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -2123,12 +2123,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -2145,7 +2145,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -2201,7 +2201,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1( + auto input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, @@ -2243,7 +2243,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2277,7 +2277,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2469,9 +2469,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = Literal::MakeTuple( - {Literal::CreateR2({{true, true}, {true, false}}).get(), - Literal::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), + LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -2825,8 +2825,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + std::unique_ptr a_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = ConstantLiteral(&builder, *a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2887,8 +2888,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); - auto x_literal = Literal::CreateR1({1, 2, 3}); - auto y_literal = Literal::CreateR1({4, 5}); + auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); + auto y_literal = LiteralUtil::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 217673c8cbc..6a024798f9e 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); + input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,12 +242,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -267,12 +267,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -298,11 +298,11 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) .get(), - Literal::CreateR1(std::vector(260, 1.0f)).get(), - Literal::CreateR1(std::vector(260, 0.0f)).get()}); + LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), + LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -331,11 +331,12 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D( + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) .get(), - Literal::CreateR1(std::vector(1, 15.0f)).get(), - Literal::CreateR1(std::vector(1, 125.0f)).get()}); + LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), + LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -362,12 +363,12 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) .get(), - Literal::CreateR1({0, 0}).get(), - Literal::CreateR1({16, 20}).get()}); + LiteralUtil::CreateR1({0, 0}).get(), + LiteralUtil::CreateR1({16, 20}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -513,11 +514,12 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto expected_normalized = Literal::CreateR4FromArray4D(normalized); + auto expected_normalized = + LiteralUtil::CreateR4FromArray4D(normalized); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -526,9 +528,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto offset_activations = Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto expected = Literal::MakeTuple({expected_normalized.get(), - Literal::CreateR1(mean).get(), - Literal::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTuple( + {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), + LiteralUtil::CreateR1(var).get()}); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -613,11 +615,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -800,14 +802,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { }); auto expected_grad_activation = - Literal::CreateR4FromArray4D(grad_activation); + LiteralUtil::CreateR4FromArray4D(grad_activation); - auto input_literal = Literal::CreateR4FromArray4D(input_array); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); auto grad_output_literal = - Literal::CreateR4FromArray4D(grad_output_array); + LiteralUtil::CreateR4FromArray4D(grad_output_array); auto input_parameter = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -833,9 +835,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { grad_output_parameter, epsilon, feature_index); auto expected = - Literal::MakeTuple({expected_grad_activation.get(), - Literal::CreateR1(grad_scale).get(), - Literal::CreateR1(grad_offset).get()}); + LiteralUtil::MakeTuple({expected_grad_activation.get(), + LiteralUtil::CreateR1(grad_scale).get(), + LiteralUtil::CreateR1(grad_offset).get()}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index f40d03bea79..747c82b502c 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -95,18 +95,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(4), static_cast(5)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(5), static_cast(5)}) .get()}); @@ -139,17 +139,17 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, {{static_cast(3.f)}, {static_cast(3.f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(0), static_cast(0)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(16), static_cast(20)}) .get()}); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 91aba9a8de3..50dd574624b 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); @@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); @@ -290,13 +291,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3( + ConstantLiteral(&b, *LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -365,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -390,7 +391,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { Add(r3h, r1h); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -398,39 +399,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}, {3, 4}}})); + auto r1 = + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -438,40 +440,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); - auto r1 = - ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r1 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -612,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -626,22 +628,24 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1}, {2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -650,11 +654,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); - auto expected = - Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -663,11 +667,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); - auto expected = - Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -676,11 +680,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); - auto expected = - Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -691,7 +695,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -699,7 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = Mul(r3, ConstantR0(&b, -2)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -720,7 +724,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = Mul(r3, ConstantR0(&b, -1)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -733,7 +737,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3( + ConstantLiteral(&b, *LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 51b9f0d3e33..c7b94b5bbaa 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -37,7 +37,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -46,14 +46,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, - error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -63,14 +63,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -86,18 +86,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), LiteralSlice(*result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -116,7 +116,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -125,15 +125,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -143,15 +143,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -166,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -176,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(input_data))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -196,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -207,7 +209,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -218,14 +220,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -238,15 +240,16 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(to_broadcast))); + LiteralUtil::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -260,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -280,7 +284,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3FromArray3D(input_vals))); + LiteralUtil::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -291,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index bc64a19ce22..2086e38b919 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -76,7 +77,8 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = ConstantLiteral(&builder, *Literal::CreateR0(42.0)); + auto constant = + ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -85,8 +87,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *Literal::CreateR1({})); - auto y = ConstantLiteral(&builder, *Literal::CreateR1({})); + auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -95,8 +97,10 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *Literal::CreateR1({1.0f, 2.0f})); - auto y = ConstantLiteral(&builder, *Literal::CreateR1({2.0f, 3.0f})); + auto x = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + auto y = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -129,15 +133,15 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*Literal::CreateR0(1.0f))); + client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); - auto elem = Literal::CreateR0(42.0); - auto tuple = Literal::MakeTuple({elem.get()}); + auto elem = LiteralUtil::CreateR0(42.0); + auto tuple = LiteralUtil::MakeTuple({elem.get()}); Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 1ad57c075b2..0bc8facfe2c 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -36,7 +36,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); - auto param_literal = Literal::CreateR1({1.1f, 2.2f}); + auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); @@ -85,12 +85,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = Literal::CreateR0(1.1f); + auto f32_literal = LiteralUtil::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = Literal::CreateR1U8("hola"); + auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index dafd6ebabbe..ef784da457b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -157,7 +157,7 @@ string ClientLibraryTestBase::ExecuteToString( void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = Literal::CreateR1(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -295,7 +295,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -347,7 +347,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -389,7 +389,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = Literal::CreateR1U8(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); VLOG(1) << "expected: " << expected_literal->ToString(); VLOG(1) << "actual: " << actual->ToString(); @@ -560,8 +560,9 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return ConstantLiteral( - builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); + return ConstantLiteral(builder, use_bfloat16_ + ? *LiteralUtil::ConvertF32ToBF16(literal) + : literal); } std::unique_ptr @@ -582,7 +583,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( const Literal* param_literal = &literal; std::unique_ptr converted_literal; if (use_bfloat16_) { - converted_literal = Literal::ConvertF32ToBF16(literal); + converted_literal = LiteralUtil::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } std::unique_ptr data = diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 5361ae6783c..fcc9347db51 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -284,7 +285,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); + return AddParam(*LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -299,13 +300,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), + return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), builder); } @@ -410,7 +412,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -426,7 +428,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -436,7 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -452,7 +454,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -462,7 +464,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -478,7 +480,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -488,7 +490,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -504,7 +506,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -514,7 +516,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -530,7 +532,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -539,9 +541,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR0(value); + std::unique_ptr literal = LiteralUtil::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -553,9 +555,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR1(values); + std::unique_ptr literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -567,9 +569,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -581,9 +583,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 831b863998f..6ce2f844a34 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -56,7 +56,7 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout( + LiteralUtil::CreateR2WithLayout( {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( @@ -112,9 +112,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr const_arg, - client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -136,7 +136,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(auto results, client_->ExecuteParallel(computation_instances)); - auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + auto expected_result = LiteralUtil::CreateR2({{6, 8}, {10, 12}}); TF_ASSERT_OK_AND_ASSIGN( auto result_literal, diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index eb211dd8ff3..ff382462867 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -50,7 +50,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR0(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -67,7 +67,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +89,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*Literal::CreateR0(42.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*Literal::CreateR0(123.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*Literal::CreateR0(456.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -143,12 +143,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = Literal::CreateR2WithLayout( + auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = Literal::CreateR2WithLayout( + auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 1a396b090c6..64bf8b3b387 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -207,7 +207,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = - Literal::CreateR1({4, 6}); + LiteralUtil::CreateR1({4, 6}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -221,7 +221,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = Literal::CreateR0(5); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -242,8 +242,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { &b, &layout_proto)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, - LayoutUtil::MakeLayout(layout)); + LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( expected_literal->shape(), computed->shape())); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 1161b560b7b..9f288634c0f 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -534,8 +534,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR0(2.f); - auto y_literal = Literal::CreateR0(3.f); + auto x_literal = LiteralUtil::CreateR0(2.f); + auto y_literal = LiteralUtil::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -556,9 +556,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -584,9 +584,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = Literal::CreateR3FromArray3D(x3d); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index ee3c83039bf..35f1400fb2a 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -344,8 +344,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(12.0f).get(), - Literal::CreateR0(25.0f).get()}), + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), + LiteralUtil::CreateR0(25.0f).get()}), {}, error_spec_); } @@ -361,8 +361,9 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR1({13.0f, 16.0f}).get(), - Literal::CreateR1({26.0f, 30.0f}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), + LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), {}, error_spec_); } @@ -399,9 +400,10 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(true).get(), - Literal::CreateR0(12.2f).get(), - Literal::CreateR1({12.8f, 14.6f}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(true).get(), + LiteralUtil::CreateR0(12.2f).get(), + LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), {}, error_spec_); } @@ -443,12 +445,14 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR0(46.6f).get(), - Literal::CreateR1({54.4f, 58.4f}).get()}) + *LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(46.6f).get(), + LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) .get(), - Literal::MakeTuple({Literal::CreateR1({62.1f, 67.4f}).get(), - Literal::CreateR0(9.3f).get()}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), + LiteralUtil::CreateR0(9.3f).get()}) .get()}), {}, error_spec_); } @@ -607,8 +611,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(a).get(), - Literal::CreateR0(b).get()}), + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), + LiteralUtil::CreateR0(b).get()}), {}, error_spec_); }; diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index cc5d3b11767..71d72a9828c 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,8 +110,8 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral( - &builder, *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *Literal::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -141,7 +141,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { }); input_array.FillWithPZ(pz); std::unique_ptr input_literal = - Literal::CreateR4FromArray4D(input_array); + LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); @@ -159,22 +159,23 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *Literal::MakeTuple( - {Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); + LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, + LiteralSlice(*result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *Literal::CreateToken()); + ConstantLiteral(&builder, *LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 292942a49e2..dca57fd1c70 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -145,7 +145,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -164,7 +164,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -182,7 +182,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -199,7 +199,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -216,7 +216,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -253,7 +253,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -391,7 +391,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -411,7 +411,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 7605ebf4c0e..944366410b1 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -93,7 +93,8 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) + client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0f6d54d042d..a8b8f74ca96 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -434,15 +434,15 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); @@ -497,15 +497,15 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); @@ -561,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(param0)), - std::move(*Literal::CreateFromArray(param1))}, + {std::move(*LiteralUtil::CreateFromArray(param0)), + std::move(*LiteralUtil::CreateFromArray(param1))}, error_spec_); } @@ -617,18 +617,18 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); - auto expected_r1 = Literal::CreateR1(expect_elems); + auto expected_r1 = LiteralUtil::CreateR1(expect_elems); auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) .ConsumeValueOrDie(); @@ -737,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } @@ -761,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { filter_data.FillIota(10); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}); + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index c31d033bb0f..8792e7781b1 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -1333,17 +1333,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XlaBuilder builder(TestName()); - auto gradients_flat = Literal::CreateR1({1}); + auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto weights_flat = Literal::CreateR1({1, 10, 100}); + auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto weights = ConstantLiteral(&builder, *weights_literal); - auto expected_flat = Literal::CreateR1({10}); + auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); @@ -1357,17 +1357,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder builder(TestName()); - auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); + auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); auto activations = ConstantLiteral(&builder, *activations_literal); - auto gradients_flat = Literal::CreateR1({100, 10, 1}); + auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto expected_flat = Literal::CreateR1({13, 24, 130}); + auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index fef42885e51..1dc6ff0f4f5 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,37 +58,38 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*Literal::CreateR0(true)); + TestCopyOp(*LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*Literal::CreateR1({})); + TestCopyOp(*LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*Literal::CreateR1({1, 2, 3})); + TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*Literal::CreateR4( + TestCopyOp(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. - auto literal = Literal::CreateR0(42.0); + auto literal = LiteralUtil::CreateR0(42.0); Shape shape = literal->shape(); auto param0 = builder.AddInstruction( @@ -109,7 +110,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -131,7 +132,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape_do_not_use()->mutable_layout(); @@ -168,7 +169,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR3FromArray3D(a); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -202,7 +203,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR4FromArray4D(a); + std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index b151187c4b8..d12a4e7fcd7 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -45,7 +45,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal = Literal::CreateR1({1, 2, 3}); + auto literal = LiteralUtil::CreateR1({1, 2, 3}); EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); } @@ -66,10 +66,10 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); EXPECT_EQ( - *Literal::MakeTuple({literal0.get(), literal1.get()}), + *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); } @@ -93,9 +93,9 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); - EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); + EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get()})); } diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index d1516a28b0b..90f3d1b874f 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -74,7 +74,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -95,7 +95,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -111,7 +111,7 @@ XLA_TEST_F(CustomCallTest, auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index acba67491d2..a6a233e71aa 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -171,7 +171,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index cf2e645d472..d86fd7cc2d4 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *Literal::CreateR2({{19, 22}, {43, 50}}), + *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -194,11 +195,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -217,14 +218,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,9 +287,10 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_lit = + LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( + param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); @@ -297,7 +299,7 @@ void ParametricDotTest::TestImpl() { Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); std::unique_ptr dot_rhs_lit = - Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); + LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); @@ -307,7 +309,7 @@ void ParametricDotTest::TestImpl() { if (param.has_addend) { addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); - addend_lit = Literal::CreateR2FromArray2DWithLayout( + addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); @@ -476,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -510,12 +512,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -583,7 +585,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -591,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -629,13 +631,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -664,15 +666,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); XlaBuilder builder(this->TestName()); @@ -733,15 +737,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -782,15 +786,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index f3c258a4d4c..b063b6bdef1 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be an ArraySlice, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -349,15 +349,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*Literal::CreateR0(input_value_int) + std::move(*LiteralUtil::CreateR0(input_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_value = - std::move(*Literal::CreateR0(update_value_int) + std::move(*LiteralUtil::CreateR0(update_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_value = - std::move(*Literal::CreateR0(expected_value_int) + std::move(*LiteralUtil::CreateR0(expected_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -380,15 +380,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, tensorflow::gtl::ArraySlice expected_values_int) { Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR1(update_values_int) + std::move(*LiteralUtil::CreateR1(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -411,15 +411,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR2FromArray2D(update_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -442,15 +442,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR3FromArray3D(update_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -520,7 +520,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - Literal::CreateR3FromArray3D(values); + LiteralUtil::CreateR3FromArray3D(values); LOG(INFO) << name << ":" << literal->ToString(); } }; @@ -695,7 +695,7 @@ void BM_DynamicSlice(int num_iters) { XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = Literal::CreateR4( + auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = ConstantLiteral(&builder, *input_literal); @@ -715,7 +715,7 @@ void BM_DynamicSlice(int num_iters) { start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index ddc6a7db187..ebba13c5b39 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 74cf8b213e0..86bfaea4ef4 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -39,7 +39,7 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); std::unique_ptr input_literal = - Literal::CreateFromDimensions(F32, {input_size}); + LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index f7f9a87413e..dc644779357 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -90,7 +90,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(operand_data[i]))); + LiteralUtil::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -116,7 +116,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = Literal::CreateR2FromArray2D(answer_data); + auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); @@ -187,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); + LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.62, 2.72, 3.14}}))); + LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, false, true}, {false, true, false}}))); + LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -223,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.5}, {2.72}}), + *LiteralUtil::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -234,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0, 3.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -2.0, -2.0}}))); + LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -249,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -270,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { auto hlo_module = CreateNewModule(); auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto x = builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); auto y = builder.AddInstruction( @@ -293,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -309,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -317,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(5), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -332,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -347,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -362,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -377,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -392,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -407,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -422,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -437,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -452,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) @@ -460,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -468,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -478,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -486,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {2}), const0, {})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -496,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -514,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); @@ -536,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0)); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -553,16 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -571,9 +572,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -591,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -603,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -611,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -625,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(-15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -633,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -675,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{462, 2145}, {24871, 62491}}), + *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -687,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -711,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({8}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -784,7 +785,7 @@ ENTRY main { )"; std::unique_ptr operand = - Literal::CreateR2({{0., 0.}, {1., 0.}}); + LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -794,7 +795,7 @@ ENTRY main { test_runner_.Execute(std::move(module), {operand.get()}, /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), *result)); } @@ -838,19 +839,19 @@ void BM_ParallelFusion(int num_iters) { // Transfer literals to device. auto param0_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = client->LiteralToShapedBuffer(*param2_literal, device_ordinal) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index b8404826b16..9178b505958 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -63,8 +63,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -84,8 +85,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -105,9 +107,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -127,9 +129,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -149,9 +151,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -171,11 +173,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -195,11 +197,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -219,8 +221,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -240,9 +243,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -261,8 +264,9 @@ ENTRY main { window_bounds={1, 0} } )"; - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -292,11 +296,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + LiteralUtil::CreateR1({0, 1, 1, 0, 0, 1}); RunTest(hlo_text, {operand.get(), gather_indices.get(), in_bounds_mask.get()}); @@ -328,11 +332,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + LiteralUtil::CreateR1({0, 1, 1, 0, 0, 1}); RunTest(hlo_text, {operand.get(), gather_indices.get(), in_bounds_mask.get()}); @@ -353,9 +357,9 @@ ENTRY main { window_bounds={1,3,2} } )"; - std::unique_ptr operand = Literal::CreateR3( + std::unique_ptr operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -374,8 +378,8 @@ ENTRY main { window_bounds={1} } )"; - std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -395,8 +399,8 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -419,8 +423,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -443,9 +448,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -468,9 +473,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -493,11 +498,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -521,11 +526,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -548,8 +553,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -572,9 +578,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -609,12 +615,13 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Gather(operand, indices, dim_numbers, {1, 3}); std::vector expected = {}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, - client_->TransferToServer(*Literal::CreateR2( - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr operand_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*Literal::CreateR1({0, 2}))); + client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index fd851188490..73a47eda721 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index d1b8a6cf0b2..31a099c15f1 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -154,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); } template @@ -175,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); } template @@ -222,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); } template @@ -231,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index bbac7285aef..f297b2b847f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,8 +31,9 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr literal = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } @@ -42,11 +43,13 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr lhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); - std::unique_ptr rhs = Literal::MakeTuple({ - Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), + std::unique_ptr rhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(64).get(), + LiteralUtil::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -55,8 +58,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = Literal::CreateR0(2); - auto four = Literal::CreateR0(4); + auto two = LiteralUtil::CreateR0(2); + auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -98,8 +101,8 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { - auto expected = Literal::CreateR1({1, 2, 3}); - auto actual = Literal::CreateR1({4, 5, 6}); + auto expected = LiteralUtil::CreateR1({1, 2, 3}); + auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); @@ -107,25 +110,26 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { } TEST(LiteralTestUtilTest, NearComparatorR1) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); } diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 082bc34136e..13df83fffff 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" @@ -64,7 +65,7 @@ class LLVMCompilerTest : public ::testing::Test { // Create HLO module, and run the compiler. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -86,7 +87,7 @@ class LLVMCompilerTest : public ::testing::Test { void TestMultiModuleCompilation(LLVMCompiler *compiler) { HloComputation::Builder builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); std::unique_ptr hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 9191be9fd90..0df50150aee 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index fd74cadea2f..7c003fb81fe 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -68,7 +68,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), @@ -81,7 +81,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), @@ -95,7 +95,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( @@ -109,7 +109,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ -298,12 +298,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = Literal::MakeTuple( - {Literal::CreateR1({2.0, 4.0, 6.0}).get(), - Literal::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); + auto y_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); auto x_buffer = LiteralToShapedBuffer(*x_literal); auto y_buffer = LiteralToShapedBuffer(*y_literal); @@ -344,12 +344,12 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}) + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) .get(), - Literal::CreateR1({222.0, -2.0, 10.0}).get()}); + LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -377,9 +377,9 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -429,10 +429,10 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // -tuple_index}. std::vector> arg_elements; for (int i = 0; i < kElementCount; ++i) { - arg_elements.push_back(Literal::CreateR1({1.0f * i, -1.0f * i})); + arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } std::unique_ptr arg_literal = - Literal::MakeTupleOwned(std::move(arg_elements)); + LiteralUtil::MakeTupleOwned(std::move(arg_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -480,12 +480,13 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { std::vector> inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { - inner_tuple_elements.push_back(Literal::CreateR0(i + j)); + inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } outer_tuple_elements.push_back( - Literal::MakeTupleOwned(std::move(inner_tuple_elements))); + LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements))); } - auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); + auto arg_literal = + LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -524,11 +525,11 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::unique_ptr arg_literal = Literal::CreateR0(123.0); + std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { std::vector> arg_vector; arg_vector.push_back(std::move(arg_literal)); - arg_literal = Literal::MakeTupleOwned(std::move(arg_vector)); + arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } auto arg_buffer = LiteralToShapedBuffer(*arg_literal); @@ -551,7 +552,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -567,7 +568,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -584,7 +585,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -767,7 +768,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -791,29 +792,29 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { }; // Array shapes. - test_to_device_and_back(*Literal::CreateR0(42.0)); - test_to_device_and_back(*Literal::CreateR0(true)); - test_to_device_and_back(*Literal::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(*LiteralUtil::CreateR0(true)); + test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*Literal::MakeTuple({})); + test_to_device_and_back(*LiteralUtil::MakeTuple({})); // Non-nested tuples. test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR0(12223.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()})); // Nested tuple. - test_to_device_and_back(*Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()}) + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()}) .get(), - Literal::CreateR0(false).get()})); + LiteralUtil::CreateR0(false).get()})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -831,13 +832,13 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { }; test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *Literal::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456789000LL).get()})); + *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456789000LL).get()})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -856,7 +857,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *Literal::CreateR1({-5.0, 123.0, 42.0}), + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. @@ -881,7 +882,7 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *Literal::CreateR1({-5.0, 123.0, 42.0}), + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, @@ -916,7 +917,7 @@ void BM_LocalClientOverhead(int num_iters) { transfer_manager ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); + auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 1b3bc9d5040..7ddc6369319 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -169,7 +169,7 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(42.0); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -183,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -198,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -212,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -225,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -239,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -255,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -272,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -287,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -343,11 +343,11 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -365,12 +365,12 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -392,12 +392,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -414,15 +414,15 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); @@ -476,11 +476,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -513,8 +513,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -540,8 +540,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -565,7 +565,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(10.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 17b1807f44a..069b8a881f4 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -63,8 +63,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { Exp(data); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 - {0.36788f, 1.64872f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -92,8 +92,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { Map(&builder, {data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 - {-0.5f, 1.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -111,8 +111,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 - {3.0f, -4.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -200,12 +200,14 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 6597748c8d1..eb06b115daa 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase { const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(8.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape0, "0")); @@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShape(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); - auto actual = ExecuteAndTransfer( - std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); + auto actual = + ExecuteAndTransfer(std::move(hlo_module), + {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShape(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); + Literal expect = + std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned( - Literal::MakeTupleOwned( - Literal::MakeTupleOwned(Literal::CreateR0(42)), - Literal::CreateR0(1.0)), - Literal::MakeTupleOwned(Literal::CreateR0(3.0), - Literal::CreateR0(4))); + auto param = LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), + LiteralUtil::CreateR0(1.0)), + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), + LiteralUtil::CreateR0(4))); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR0(42)), *result)); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -232,7 +234,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); @@ -265,7 +267,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); @@ -308,12 +310,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), *result)); } @@ -338,12 +342,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR2({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), *result)); } @@ -369,13 +375,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), - Literal::CreateR1({36, 64}), - Literal::CreateR1({66, 138})), + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), *result)); } @@ -401,14 +408,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), - Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), *result)); } @@ -434,14 +442,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR2({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), *result)); } @@ -468,14 +478,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR1({14, 22}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR3( + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), *result)); } @@ -502,15 +514,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - auto init1 = Literal::CreateR0(5); - auto init2 = Literal::CreateR0(6); + auto param = + LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = LiteralUtil::CreateR0(5); + auto init2 = LiteralUtil::CreateR0(6); std::unique_ptr result = ExecuteNoHloPasses( std::move(module), {param.get(), init1.get(), init2.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{167, 172}, {176, 180}}), - Literal::CreateR2({{6, 6}, {6, 8}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{167, 172}, {176, 180}}), + LiteralUtil::CreateR2({{6, 6}, {6, 8}})), *result)); } @@ -537,19 +550,20 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3( + auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}}), - Literal::CreateR3({{{Eigen::half(1), Eigen::half(2)}, - {Eigen::half(3), Eigen::half(4)}}, - {{Eigen::half(5), Eigen::half(6)}, - {Eigen::half(7), Eigen::half(8)}}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}}), + LiteralUtil::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), *result)); } diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 2e5081bbcb6..e428fa9b5e1 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); + AddParam(*LiteralUtil::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); } @@ -147,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -166,7 +167,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + Pad(AddParam(input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); @@ -205,11 +207,11 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -251,11 +253,11 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -329,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -351,7 +353,8 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -376,7 +379,8 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -403,7 +407,8 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -430,7 +435,8 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,12 +452,13 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2620063aa49..8ba1d11b333 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -42,7 +42,8 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); + std::unique_ptr param0_literal = + LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -54,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -67,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -80,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = Literal::CreateR1U8(str); + std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -94,7 +95,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(3, 0)); + LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -106,7 +107,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -122,12 +123,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); @@ -153,7 +154,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = Literal::CreateR0(3.14159f); + std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -167,12 +168,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); Parameter(&builder, 1, literal1->shape(), "param1"); @@ -187,11 +188,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = + LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -231,7 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = Literal::CreateR1(sum_value); + std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -266,7 +268,7 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR0(i); + std::unique_ptr literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -298,7 +300,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -322,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({target + i, target + i})); + elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -354,7 +356,7 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -364,7 +366,7 @@ XLA_TEST_F(ParamsTest, // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr bool_literal = Literal::CreateR0(false); + std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = @@ -421,10 +423,10 @@ XLA_TEST_F(ParamsTest, std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({i, i})); + elements.push_back(LiteralUtil::CreateR1({i, i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -441,9 +443,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -455,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -467,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -478,7 +480,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = Literal::CreateR2({ + std::unique_ptr literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8e163e885d0..5ebf8344d2b 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -193,7 +193,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9052b188ed0..a080dd1732b 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -95,21 +95,21 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = - Literal::CreateR4({{ /*i0=0*/ - {/*i1=0*/ - {-0.246092796, -0.179497838, -0.161181688}, - {-0.151643038, -0.240213156, -0.198156}}, - {/*i1=1*/ - {-0.14222312, -0.162200093, -0.193907976}, - {-0.239411, -0.198166847, -0.172471642}}}, - { /*i0=1*/ - {/*i1=0*/ - {-0.22965157, -0.218723893, -0.129257083}, - {-0.188762426, -0.16123569, -0.181166649}}, - {/*i1=1*/ - {-0.241772294, -0.245131493, -0.160247207}, - {-0.179881215, -0.23383224, -0.121976733}}}}); + std::unique_ptr reduce_input = LiteralUtil::CreateR4( + {{ /*i0=0*/ + {/*i1=0*/ + {-0.246092796, -0.179497838, -0.161181688}, + {-0.151643038, -0.240213156, -0.198156}}, + {/*i1=1*/ + {-0.14222312, -0.162200093, -0.193907976}, + {-0.239411, -0.198166847, -0.172471642}}}, + { /*i0=1*/ + {/*i1=0*/ + {-0.22965157, -0.218723893, -0.129257083}, + {-0.188762426, -0.16123569, -0.181166649}}, + {/*i1=1*/ + {-0.241772294, -0.245131493, -0.160247207}, + {-0.179881215, -0.23383224, -0.121976733}}}}); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 4c1aa121067..04c7f316463 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -230,7 +230,8 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({input_values}); + std::unique_ptr a_literal = + LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -253,7 +254,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -282,7 +283,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -308,7 +309,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -332,7 +333,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -357,7 +358,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index c9f57cbb167..1407fca72fd 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. // clang-format off - literal_2d_ = Literal::CreateR2({ + literal_2d_ = LiteralUtil::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = Literal::CreateR3Projected({ + literal_3d_ = LiteralUtil::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - Literal::CreateR1(AsSlice(input_data)); + LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -133,7 +133,7 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = Literal::CreateR1(input_data); + std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -175,7 +175,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -209,7 +209,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -237,7 +237,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -295,7 +295,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -450,7 +450,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -482,7 +482,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -531,7 +531,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR3FromArray3D(input_data); + LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -594,7 +594,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -609,7 +609,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); @@ -623,7 +623,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); @@ -635,7 +635,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); @@ -818,7 +818,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { // input_array.FillRandom(3.14f, 0.05); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR3FromArray3D(input_array); + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = @@ -872,7 +872,8 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = Literal::CreateR1({1.0f, 4.0f}); + std::unique_ptr b_literal = + LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b = Parameter(&builder, 0, b_literal->shape(), "b"); @@ -900,7 +901,7 @@ class ReduceInitializerTest : public ReduceTest { auto init = ConstantR0(&builder, initializer); std::vector input_arr(num_elems, std::numeric_limits::lowest()); - auto input_literal = Literal::CreateR1(input_arr); + auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, @@ -950,10 +951,11 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = Literal::CreateR1(operand); + std::unique_ptr input_literal = + LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = Literal::CreateR0(init); + std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 741974480c6..c2681f70f7e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -81,7 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMaxComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -91,7 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -102,9 +104,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({1, 1, 1, 1}), &builder_); + *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -119,32 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, - ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *Literal::CreateR1({1000, 100, 10, 1, 1}), {}, - ErrorSpec(0.00001)); + *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + {}, ErrorSpec(0.00001)); } XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { @@ -156,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -171,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -185,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -202,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -224,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -247,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -272,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -289,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -308,12 +310,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -327,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -347,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -376,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { auto shape = ShapeUtil::MakeShape(F32, input_dims); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -385,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -394,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -408,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -416,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -430,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -438,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -452,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -473,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -499,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -516,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -535,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); XlaOp input = Broadcast( - CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); + CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -550,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -609,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", @@ -621,7 +625,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) @@ -647,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*stride=*/param.strides, /*padding=*/padding); std::unique_ptr expected_literal = - Literal::CreateFromArray(*expected); + LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( input_literal->shape().element_type(), AsInt64Slice(expected_literal->shape().dimensions()), param.layout); @@ -959,14 +963,14 @@ TEST_P(R3ReduceWindowTest, Add) { Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], 1.0f); std::unique_ptr input_literal = - Literal::CreateR3FromArray3DWithLayout( + LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, /*computation=*/CreateScalarAddComputation(FloatType(), &b), @@ -977,7 +981,7 @@ TEST_P(R3ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1093,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -1107,7 +1111,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1123,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1292,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = - Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1304,7 +1308,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1323,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index bebd814fa8b..d544968648d 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -91,10 +91,10 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*Literal::CreateR0(2)) + client_->TransferToServer(*LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*Literal::CreateR0(3)) + client_->TransferToServer(*LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 5812fe442b2..7c0389cfa32 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d3d6c3c7d70..46d91711a55 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -55,39 +55,39 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -97,7 +97,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); @@ -105,7 +105,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(1.0f); + auto expected_literal = LiteralUtil::CreateR0(1.0f); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -113,14 +113,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(1.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - auto expected_literal = Literal::CreateR1({-1.0f}); + auto expected_literal = LiteralUtil::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -128,12 +128,12 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -142,12 +142,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(0, 3)); + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -155,12 +155,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -168,12 +168,12 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { // Collapses a 2-dimensional row vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -181,12 +181,12 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { // Collapses a 2-dimensional column vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -194,13 +194,13 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({}); + auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -209,14 +209,14 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaBuilder builder(TestName()); auto input_literal = - Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -224,13 +224,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -239,7 +239,7 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto input_literal = Literal::CreateFromArray(*simple); + auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -247,7 +247,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -256,7 +256,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XLA_TEST_P(ReshapeTest, TransposeAsReshape) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -264,7 +264,7 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -272,12 +272,12 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); - auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -286,14 +286,14 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XLA_TEST_P(ReshapeTest, Transpose4x3) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -302,26 +302,27 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); - auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + auto expected_literal = + LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -331,7 +332,7 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -339,20 +340,20 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -362,7 +363,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -370,7 +371,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - auto expected_literal = Literal::CreateFromArray(expected); + auto expected_literal = LiteralUtil::CreateFromArray(expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -388,13 +389,13 @@ static Array3D ArrayForDocR3Tests() { XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -403,33 +404,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); + auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -438,33 +439,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); + auto expected_literal = LiteralUtil::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); - auto expected_literal = Literal::CreateR3( + auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -491,12 +492,12 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto input_literal = Literal::CreateFromArray(t2x2x2x3); + auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); - auto expected_literal = Literal::CreateR2( + auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); @@ -516,7 +517,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto input_literal = Literal::CreateFromArray(t); + auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -524,7 +525,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { /*new_sizes=*/{2, 4}); auto expected_literal = - Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -545,7 +546,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { &b, ¶meter); Reshape(parameter, dimensions, {}); - auto expected_literal = Literal::CreateR0(83.0f); + auto expected_literal = LiteralUtil::CreateR0(83.0f); ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, zero_error_spec_); } @@ -553,7 +554,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); @@ -565,7 +566,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); @@ -577,7 +578,8 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { XlaBuilder builder(TestName()); // clang-format off - auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + Array4D{ { { {0, 1}, @@ -622,16 +624,16 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - Literal::CreateR2FromArray2D(expected_array); + LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = Literal::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(*expected); } EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -642,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, @@ -658,7 +660,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -669,7 +671,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, @@ -691,7 +693,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -699,7 +701,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -713,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -721,7 +723,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -736,7 +738,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -749,7 +751,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -763,7 +765,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -785,7 +787,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = Literal::ConvertF32ToBF16(*input_literal); + auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -794,7 +796,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaBuilder builder(TestName()); - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -808,7 +810,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -820,7 +822,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = LiteralUtil::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -844,7 +846,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -854,7 +856,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -873,7 +875,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -883,7 +885,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -902,7 +904,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -912,7 +914,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -932,7 +934,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -942,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -961,7 +963,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -971,7 +973,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 662bc422248..23f0d26d93b 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -82,7 +82,7 @@ TEST_P(FloatReverseTest, Reverses) { std::vector input_vector( ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); - auto r1_literal = Literal::CreateR1(input_vector); + auto r1_literal = LiteralUtil::CreateR1(input_vector); auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 7cfca781acd..a620fe19085 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f334a8c1318..a8193c2eac0 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*Literal::CreateR0(42)); + RoundTripTest(*LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*Literal::CreateR0(42.0)); + RoundTripTest(*LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*Literal::CreateR1({})); + RoundTripTest(*LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*Literal::CreateR1({42.0, 64.0})); + RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest( + *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*Literal::CreateR4({{ + RoundTripTest(*LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*Literal::MakeTuple({})); + RoundTripTest(*LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), - Literal::CreateR1({2, 3}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), + LiteralUtil::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); + RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 3afd8c8fc88..3b603c0d315 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -162,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = Literal::CreateR0(value); + std::unique_ptr a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), @@ -226,9 +227,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR0(2.1f); - std::unique_ptr b_literal = Literal::CreateR0(5.5f); - std::unique_ptr c_literal = Literal::CreateR0(0.5f); + std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); + std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); + std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -375,8 +376,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -387,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend / divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend / divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -416,8 +418,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -428,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend % divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend % divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -440,7 +443,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = Literal::CreateR0(87919); + std::unique_ptr literal = LiteralUtil::CreateR0(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 0a173fbbbd5..b1f1e69d3cd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 3e5c01d6d47..48138e7b076 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -170,7 +170,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}}, /*strides=*/{{1, 1, 2, 1}}); - auto expected_literal = Literal::CreateR4FromArray4DWithLayout( + auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); @@ -197,7 +197,7 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - auto literal = Literal::CreateR1(input); + auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); auto original = Parameter(&builder, 0, literal->shape(), "p0"); @@ -368,7 +368,7 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - auto literal = Literal::CreateR2FromArray2DWithLayout( + auto literal = LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); @@ -463,7 +463,7 @@ class SliceR4Test : public ClientLibraryTestBase, auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); - auto literal = Literal::CreateR4FromArray4DWithLayout( + auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 20c7c30878a..26479370132 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -110,7 +111,7 @@ StatusOr> MakeFakeLiteralInternal( MakeFakeLiteralInternal(element_shape, engine)); elements.push_back(std::move(element)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } if (engine == nullptr) { return Literal::CreateFromShape(shape); @@ -220,7 +221,7 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( start_indices[i] = generator(*engine); } } - return Literal::CreateR1(start_indices); + return LiteralUtil::CreateR1(start_indices); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -318,9 +319,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant != nullptr) { switch (constant_type) { case ConstantType::kZero: - return Literal::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); case ConstantType::kOne: - return Literal::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index a8689f64981..e59f215a9a3 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index e9008fa48aa..f6f4a17bca2 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -37,7 +37,7 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -53,7 +53,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -64,7 +64,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -98,7 +98,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); builder.AddInstruction(HloInstruction::CreateAfterAll({param})); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -184,7 +184,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); - auto arg = Literal::CreateR0(true); + auto arg = LiteralUtil::CreateR0(true); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {arg.get()})); EXPECT_EQ(42, result->Get({})); @@ -195,7 +195,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); - auto arg = Literal::CreateR0(false); + auto arg = LiteralUtil::CreateR0(false); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {arg.get()})); EXPECT_EQ(7, result->Get({})); diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 86babb58c9d..0f86b7f20f9 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -68,7 +68,7 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = Literal::CreateR0(42); + std::unique_ptr literal = LiteralUtil::CreateR0(42); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -84,7 +84,7 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { XLA_TEST_F(TransferManagerTest, TransferR1F32) { std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -102,7 +102,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = Literal::CreateR1(test_vector); + std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -118,7 +118,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(test_string); + std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -134,7 +134,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { XLA_TEST_F(TransferManagerTest, TransferR2F32) { std::unique_ptr literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -151,7 +151,7 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -172,10 +172,10 @@ XLA_TEST_F(TransferManagerTest, } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = Literal::MakeTuple({}); + std::unique_ptr literal = LiteralUtil::MakeTuple({}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -203,13 +203,13 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-10.0f, 123.0f}).get()}); + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -223,7 +223,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); @@ -238,12 +238,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) .get(), - Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), - Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -265,25 +265,25 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + std::unique_ptr literal1 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = Literal::MakeTuple( - {Literal::CreateR0(456.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - Literal::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(456.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-98.0f, 153.0f}).get()}); + LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); @@ -325,10 +325,10 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { std::vector> tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( - Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } std::unique_ptr literal = - Literal::MakeTupleOwned(std::move(tuple_elements)); + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); auto device_buffer = AllocateDeviceBuffer(literal->shape()); TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, device_buffer)); @@ -357,10 +357,10 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { std::vector> tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( - Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } std::unique_ptr literal = - Literal::MakeTupleOwned(std::move(tuple_elements)); + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); auto device_buffer = AllocateDeviceBuffer(literal->shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index ec11508891d..bf86c5dfb6a 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -49,10 +49,10 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); @@ -64,9 +64,9 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), - Literal::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar1).get(), + LiteralUtil::CreateR0(constant_scalar2).get()}); ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); @@ -86,10 +86,10 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -100,8 +100,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), - Literal::CreateR1({}).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), + LiteralUtil::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -109,7 +110,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); - auto expected = Literal::MakeTuple({}); + auto expected = LiteralUtil::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -193,9 +194,9 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = - Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), - Literal::CreateR1(constant_vector).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2(constant_matrix).get(), + LiteralUtil::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -216,8 +217,8 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = - Literal::MakeTuple({Literal::CreateR0(direction).get(), - Literal::CreateR0(!direction).get()}); + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), + LiteralUtil::CreateR0(!direction).get()}); ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, error_spec_); @@ -284,8 +285,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -328,8 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), - Literal::CreateR1(vec2).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), + LiteralUtil::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -403,8 +406,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -414,13 +418,13 @@ XLA_TEST_F(TupleTest, NestedTuples) { ConstantR0(&builder, 42.0)}); Tuple(&builder, {inner_tuple, ConstantR1(&builder, {22.0, 44.0})}); - auto expected_v1 = Literal::CreateR1({1.0, 2.0}); - auto expected_s = Literal::CreateR0(42.0); + auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); + auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - Literal::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = Literal::CreateR1({22.0, 44.0}); + LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); auto expected = - Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -440,14 +444,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::MakeTuple( + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::MakeTuple( { - Literal::CreateR1({1.0, 2.0, 3.0}).get(), - Literal::CreateR1({4.0, 5.0, 6.0}).get(), + LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - Literal::CreateR1({7.0, 8.0, 9.0}).get(), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -478,11 +482,12 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*Literal::MakeTuple( - {Literal::CreateR0({1, 2}).get(), - Literal::MakeTuple( - {Literal::CreateR1({{10, 20}, {30, 40}}).get(), - Literal::CreateR2( + ->TransferToServer(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0({1, 2}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) + .get(), + LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, {{10000, 20000}, {30000, 40000}}}) @@ -491,11 +496,13 @@ XLA_TEST_F(TupleTest, ComplexTuples) { .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ - ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + ->TransferToServer( + *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); - auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, - {{1011, 2022}, {3031, 4042}}, - {{10011, 20022}, {30031, 40042}}}); + auto sum = + LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { @@ -505,9 +512,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) { : complex64(1, -2)); }) .ok()); - auto expected = - Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), - Literal::CreateR0({123, 456}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), + LiteralUtil::CreateR0({123, 456}).get()}); ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -530,10 +537,11 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + auto param = + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})), + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), *result)); } diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 929b1ca7fb9..a90a6fb0a5b 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -101,7 +101,7 @@ void UnaryOpTest::AbsTestHelper() { Abs(arg); std::unique_ptr expected = - Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); + LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -113,7 +113,7 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = Literal::CreateR1( + std::unique_ptr expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -128,7 +128,7 @@ void UnaryOpTest::SignAbsTestHelper() { Sub(Mul(sign, ConvertElementType(abs, C64)), arg); std::unique_ptr expected = - Literal::CreateR1({0, 0, 0, 0}); + LiteralUtil::CreateR1({0, 0, 0, 0}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -173,7 +173,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); std::unique_ptr expected = - Literal::CreateR0({-2.6f, 0.8f}); + LiteralUtil::CreateR0({-2.6f, 0.8f}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bbd67cd8d7c..29befef92e4 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. auto expected_data = - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = Literal::MakeTuple({expected_data.get()}); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto expected = LiteralUtil::MakeTuple({expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(N); - auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); - auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); - auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); + auto expected_counter = LiteralUtil::CreateR0(N); + auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_predicate = Literal::CreateR0(true); - auto expected = - Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_predicate = LiteralUtil::CreateR0(true); + auto expected = LiteralUtil::MakeTuple( + {expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } @@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR0(7); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR0(7); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); While(cond_computation, body_computation, t); - auto expected_element = Literal::CreateR1({1, 1}); + auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - Literal::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(42))); + client_->TransferToServer(*LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(1))); + client_->TransferToServer(*LiteralUtil::CreateR0(1))); - auto add1 = Literal::CreateR0(15); - auto add2 = Literal::CreateR0(16); - auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + auto add1 = LiteralUtil::CreateR0(15); + auto add2 = LiteralUtil::CreateR0(16); + auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1226,9 +1227,9 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto while_instruction = While(condition, body, init); GetTupleElement(while_instruction, 3); - TF_ASSERT_OK_AND_ASSIGN(auto param_value, - client_->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0}, {-1.0, -2.0}}))); + TF_ASSERT_OK_AND_ASSIGN( + auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 56702feab9a..897123d7606 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index e45e5291c9b..708e8c80d8b 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 23070b66387..92f9b4f9f0e 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 373c0d2d8d8..24e0784741a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 0a1235b5e04..159ac1b7e1b 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 70cf2fb1b8a..4ea02faffcd 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,8 +31,9 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = Literal::CreateR2({ - {3.14, 2.17}, {1.23, 4.56}, + auto literal = LiteralUtil::CreateR2({ + {3.14, 2.17}, + {1.23, 4.56}, }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index e4a052c8f1c..55501827f29 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -74,7 +74,7 @@ cc_library( srcs = ["replay_computation.cc"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -123,7 +123,7 @@ tf_cc_binary( name = "show_literal", srcs = ["show_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -145,7 +145,7 @@ tf_cc_binary( name = "show_text_literal", srcs = ["show_text_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:text_literal_reader", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3a7917cf304..854e797ec2e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -43,7 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index fe8e72ba32b..51909190a3e 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 8525873e913..48c83748118 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/text_literal_reader.h" #include "tensorflow/compiler/xla/types.h"