diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 99046c0bd76..3cc68f2a1a4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, } string message = absl::StrCat( "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def), ".\n"); + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); absl::StrAppend(&message, "Uncompilable nodes:"); for (const auto& node_info : uncompilable_node_info) { string node_message = diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index e0ec990462b..8c24f182f5c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs( se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. - arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers_.resize(kernel->xla_input_shapes.size()); - arg_ptrs_ = std::vector(arg_buffers_.size()); + arg_ptrs_ = std::vector(kernel->xla_input_shapes.size()); // Pass remaining parameters. const Tensor* t; @@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = absl::make_unique( + arg_buffers_.emplace_back( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); - arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs_[i] = arg_buffers_[i].get(); + arg_buffers_.back().set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = &arg_buffers_.back(); } } } diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 511e0f1451a..cf68dcb7dd6 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -165,7 +165,7 @@ class XlaComputationLaunchContext { se::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; bool use_multiple_streams_; - std::vector> arg_buffers_; + std::deque arg_buffers_; std::vector arg_ptrs_; }; diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 6eff7dbd084..1a508bdb190 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -279,22 +279,6 @@ cc_library( ], ) -tf_cc_test( - name = "tftext_utils_test", - size = "small", - srcs = ["utils/lstm_utils_test.cc"], - deps = [ - ":lstm_utils", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "stateful_ops_utils", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 923efdbaf9d..edb533c9442 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2297,26 +2297,17 @@ def TFL_PReluOp : TFL_Op<"prelu", [ NoSideEffect, ResultsBroadcastableShape, TFL_GpuTargetOp, - TFL_OperandHasRankAtMost<0, 
4>, - TFL_OperandHasRankAtMost<1, 4>, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, BinaryOpSameElementTypeConstraint, PredOpTrait<"input and output must have the same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, - PredOpTrait<"'alpha' should have one less rank than 'input'.", - Or<[TFL_OperandIsUnrankedPred<0>, - TFL_OperandIsUnrankedPred<1>, - CPred<"$_op.getOperand(0).getType().cast().getRank() == " - "$_op.getOperand(1).getType().cast().getRank() " - "+ 1">]>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Parameterized Relu operator"; let description = [{ Parameterized Relu operator x -> x >= 0 ? x : (alpha * x) where alpha is a trainable tensor. - alpha should have one less rank than the input as it doesn't have the batch - dimension, and the other dimensions either should be the same size as input - or size 1, where it is broadcasted in the second case. + input and alpha should be the same size as input or be broadcastable. }]; let arguments = ( diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index a07b7b8dd1d..8a2faebcbe6 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, std::vector node_names; std::vector node_dtypes; std::vector> node_shapes; - std::vector node_mins; - std::vector node_maxs; + std::vector> node_mins; + std::vector> node_maxs; // Populate quantization specs. TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 51fcbb97360..ab80746f8b7 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer( std::vector node_names; std::vector node_dtypes; std::vector> node_shapes; - std::vector node_mins; - std::vector node_maxs; + std::vector> node_mins; + std::vector> node_maxs; // Populate quantization specs. 
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index a1401323e89..8f2c8bc362c 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { return RegisterCustomBuiltinOps(extra_tf_opdefs); } -Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - mlir::TFL::QuantizationSpecs* quant_specs, - std::vector* node_names, - std::vector* node_dtypes, - std::vector>* node_shapes, - std::vector* node_mins, - std::vector* node_maxs) { +Status PopulateQuantizationSpecs( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector>* node_mins, + std::vector>* node_maxs) { quant_specs->inference_input_type = ConvertIODataTypeToDataType(toco_flags.inference_input_type()); tensorflow::DataType inference_type = @@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, flag.shape().dims().end())); // Currently, only UINT8 and INT8 require inputs stats if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) { - TF_ASSIGN_OR_RETURN( - auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(), - inference_type)); - node_mins->push_back(min_max.first); - node_maxs->push_back(min_max.second); + if (flag.has_mean_value() && flag.has_std_value()) { + TF_ASSIGN_OR_RETURN( + auto min_max, InputStatsToMinMax(flag.mean_value(), + flag.std_value(), inference_type)); + node_mins->push_back(min_max.first); + node_maxs->push_back(min_max.second); + } else { + node_mins->push_back(llvm::None); + node_maxs->push_back(llvm::None); + } } } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 3ea36e5eb1d..87e73912a46 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags); // Populate quantization specs (or not) given user specified ranges for each // input arrays. -Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - mlir::TFL::QuantizationSpecs* quant_specs, - std::vector* node_names, - std::vector* node_dtypes, - std::vector>* node_shapes, - std::vector* node_mins, - std::vector* node_maxs); +Status PopulateQuantizationSpecs( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector>* node_mins, + std::vector>* node_maxs); // Convert imported MLIR file to TfLite flatbuffer. // This will also run relevant passes as well. 
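Note on the xla_launch_util change near the top of this patch (arg_buffers_ becoming std::deque<ShapedBuffer> instead of std::vector<std::unique_ptr<ShapedBuffer>>): unlike std::vector, std::deque never relocates existing elements when growing at the back, so the raw pointers kept in arg_ptrs_ stay valid without the extra unique_ptr indirection. A minimal standalone sketch of the guarantee this relies on (standard C++ only, not TensorFlow code):

#include <cassert>
#include <deque>
#include <vector>

int main() {
  std::deque<int> stable;
  std::vector<int*> ptrs;
  for (int i = 0; i < 1000; ++i) {
    stable.emplace_back(i);
    // Insertion at either end of a std::deque invalidates iterators but not
    // references or pointers to existing elements, so storing raw pointers
    // alongside the container is safe.
    ptrs.push_back(&stable.back());
  }
  for (int i = 0; i < 1000; ++i) assert(*ptrs[i] == i);
  // With std::vector this pattern would become invalid at the first
  // reallocation, which is why the old code held unique_ptr elements.
  return 0;
}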
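Note on the PopulateQuantizationSpecs signature change above: node_mins/node_maxs now carry llvm::Optional<double>, so inputs without mean/std statistics are recorded as llvm::None instead of being dropped or given a sentinel, keeping the vectors positionally aligned with node_names. A rough sketch of that convention (the InputStats struct and CollectRanges helper are illustrative only, not part of the patch):

#include <vector>
#include "llvm/ADT/Optional.h"

// Hypothetical per-input statistics; has_min_max mirrors
// flag.has_mean_value() && flag.has_std_value() in the patch.
struct InputStats {
  bool has_min_max;
  double min;
  double max;
};

void CollectRanges(const std::vector<InputStats>& inputs,
                   std::vector<llvm::Optional<double>>* node_mins,
                   std::vector<llvm::Optional<double>>* node_maxs) {
  for (const InputStats& s : inputs) {
    if (s.has_min_max) {
      node_mins->push_back(s.min);
      node_maxs->push_back(s.max);
    } else {
      // Record "unspecified" explicitly so downstream consumers (e.g. the
      // prepare_quantize.cc change further down) can skip these arguments.
      node_mins->push_back(llvm::None);
      node_maxs->push_back(llvm::None);
    }
  }
}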
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc index 6b897bd5608..3edd9c36760 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc @@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, absl::string_view inference_type, QuantizationSpecs* quant_specs) { std::vector input_nodes = absl::StrSplit(node_names, ','); - std::vector node_mins; + std::vector> node_mins; if (!min_values.empty()) { std::vector node_mins_str = absl::StrSplit(min_values, ','); for (int i = 0; i < node_mins_str.size(); i++) { @@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, } } - std::vector node_maxs; + std::vector> node_maxs; if (!max_values.empty()) { std::vector node_maxs_str = absl::StrSplit(max_values, ','); for (int i = 0; i < node_maxs_str.size(); i++) { @@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, quant_specs); } -bool GetInputNodeQuantSpecs(const std::vector& node_names, - const std::vector& node_mins, - const std::vector& node_maxs, - tensorflow::DataType inference_type, - QuantizationSpecs* quant_specs) { +bool GetInputNodeQuantSpecs( + const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) { quant_specs->inference_type = inference_type; // If min/max are not specified, just return; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 2ffba579548..a4046553d17 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_ +#include #include #include @@ -69,7 +70,8 @@ struct QuantizationSpecs { // arguments. They are only used when `weight_quantization` is set to false, // and the model is required to have quantization parameters, either from // quantization aware training or calibration, for the remaining tensors. - std::vector> input_ranges; + std::vector, llvm::Optional>> + input_ranges; // The default ranges can be used when a tensor doesn't have quantization // parameters and couldn't be quantized. Used only for latency tests. @@ -130,11 +132,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, // Gets the quantization specification for input arrays. The array names are not // stored in the spec, and will be matched by position. The min/max will be // ignored if the inference_type isn't a quantized type. Returns true if failed. 
-bool GetInputNodeQuantSpecs(const std::vector& node_names, - const std::vector& node_mins, - const std::vector& node_maxs, - tensorflow::DataType inference_type, - QuantizationSpecs* quant_specs); +bool GetInputNodeQuantSpecs( + const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, QuantizationSpecs* quant_specs); } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 87cae3dd957..702808ac892 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -109,8 +109,8 @@ class PrepareQuantizePass // Get the min and max values from the quantization specification for the // current function function and argument index. Uses default values if // the function is specified in the `quantize_whitelist`. - std::pair GetMinMaxValuesForArgument( - llvm::StringRef func_name, int index) { + std::pair, llvm::Optional> + GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) { if (func_name == quant_specs_.target_func) { return quant_specs_.input_ranges[index]; } else { @@ -160,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { } auto min_max = GetMinMaxValuesForArgument(func_name, i); + // The input min/max or mean/std are not specified, then skip. + if (!min_max.first.hasValue() || !min_max.second.hasValue()) return; + TypeAttr params = quant::GetQuantizedTypeAttr( - builder, input_type, builder.getF64FloatAttr(min_max.first), - builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits, - narrow_range, is_signed); + builder, input_type, + builder.getF64FloatAttr(min_max.first.getValue()), + builder.getF64FloatAttr(min_max.second.getValue()), + /*quant_dim=*/-1, num_bits, narrow_range, is_signed); builder.setInsertionPoint(block, insertion_point); auto q_op = builder.create(loc, params.getValue(), arg); diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 5df57de6f71..081ba7ac6e7 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" namespace mlir { @@ -92,7 +93,9 @@ class LstmUtilsTest : public ::testing::Test { LstmUtilsTest() {} void SetUp() override { - builder_ = std::unique_ptr(new Builder(&context_)); + RegisterDialects(); + context_ = std::make_unique(); + builder_ = std::unique_ptr(new Builder(context_.get())); fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false); fused_lstm_func_cifg_ = createLstmCompositeFunc(builder_.get(), false, true); @@ -105,10 +108,17 @@ class LstmUtilsTest : public ::testing::Test { fused_ln_lstm_func_.erase(); builder_.reset(); } + + void RegisterDialects() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + } + FuncOp fused_lstm_func_; FuncOp fused_lstm_func_cifg_; FuncOp fused_ln_lstm_func_; - mlir::MLIRContext context_; + std::unique_ptr context_; std::unique_ptr builder_; }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 1df8f7fd519..9f407ea774a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1318,6 +1318,126 @@ greater than `clip_value_max` are set to `clip_value_max`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> { + let summary = "Receives a tensor value broadcast from another device."; + + let description = [{ + }]; + + let arguments = (ins + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + ); + + TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>; +} + +def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> { + let summary = "Broadcasts a tensor value to one or more other devices."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I1, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectiveGatherOp : TF_Op<"CollectiveGather", []> { + let summary = [{ +Mutually accumulates multiple tensors of identical type and shape. + }]; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "An Op to permute tensors across replicated TPU instances."; + + let description = [{ +Each instance supplies its own input. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: +`[D, A, B, C]`. 
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + I32Tensor:$source_target_pairs + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectiveReduceOp : TF_Op<"CollectiveReduce", [SameOperandsAndResultType]> { + let summary = [{ +Mutually reduces multiple tensors of identical type and shape. + }]; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, + TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, + I64ArrayAttr:$subdiv_offsets, + DefaultValuedAttr:$wait_for, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> { let summary = "Converts two real numbers to a complex number."; diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir index ad007d0eb50..d6c164f8160 100644 --- a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir +++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir @@ -203,12 +203,12 @@ func @moving_alloc_and_inserting_missing_dealloc(%cond : i1, %arg0 : memref<2xf3 "buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> () br ^exit(%1 : memref<2xf32>) ^exit(%arg2: memref<2xf32>): - "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + "buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } // CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() // CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK: "bufer_assignment_test.copy" +// CHECK: "buffer_assignment_test.copy" // CHECK-NEXT: dealloc // CHECK-NEXT: dealloc // CHECK-NEXT: return @@ -226,11 +226,11 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1 dealloc %1 : memref<2xf32> br ^exit(%1 : memref<2xf32>) ^exit(%arg2: memref<2xf32>): - "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + "buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } // CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK: bufer_assignment_test.copy +// CHECK: buffer_assignment_test.copy // CHECK-NEXT: dealloc // CHECK-NEXT: return @@ -240,10 +240,10 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1 func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ %0 = alloc() : memref<2xf32> "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () - "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + "buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } -// CHECK: bufer_assignment_test.copy +// CHECK: buffer_assignment_test.copy // CHECK-NEXT: dealloc // ----- @@ -253,8 +253,8 @@ func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ %0 = alloc() : memref<2xf32> 
"buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () dealloc %0 : memref<2xf32> - "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + "buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } -// CHECK: bufer_assignment_test.copy -// CHECK-NEXT: dealloc \ No newline at end of file +// CHECK: buffer_assignment_test.copy +// CHECK-NEXT: dealloc diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc index 5a0d791079c..40c115f4cbc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc @@ -29,60 +29,66 @@ limitations under the License. namespace mlir { namespace xla { namespace { + +/// This dialect independent unary operation has been defined only for testing +/// buffer assignment. +class BufferAssignmentTestUnaryOp + : public Op { + public: + using Op::Op; + static StringRef getOperationName() { return "buffer_assignment_test.unary"; } + static void build(OpBuilder& b, OperationState& state, Value source) { + state.addOperands(source); + } +}; + +/// This dialect independent lowered unary operation has been defined only for +/// testing buffer assignment. +class BufferAssignmentTestUnaryLoweredOp + : public Op::Impl> { + public: + using Op::Op; + static StringRef getOperationName() { + return "buffer_assignment_test.unary_lowered"; + } + static void build(OpBuilder& b, OperationState& state, Value source, + Value target) { + state.addOperands(source); + state.addOperands(target); + } +}; + +/// This dialect independent copy operation has been defined only for testing +/// NonVoidToVoidReturnOpConverter +class BufferAssignmentTestCopyOp + : public Op::Impl> { + public: + using Op::Op; + static StringRef getOperationName() { return "buffer_assignment_test.copy"; } + static void build(OpBuilder& b, OperationState& state, Value from, Value to) { + state.addOperands(from); + state.addOperands(to); + } +}; + +class BufferAssignmentTestDialect : public Dialect { + public: + explicit BufferAssignmentTestDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations(); + } + static StringRef getDialectNamespace() { return "buffer_assignment_test"; } +}; + /// This pass tests two provided operation converters, /// FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter, for /// Buffer Assignment. struct BufferAssignmentPreparationTestPass : mlir::PassWrapper { - /// This dialect independent unary operation has been defined only for testing - /// buffer assignment. - class BufferAssignmentTestUnaryOp - : public Op { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.unary"; - } - static void build(OpBuilder& b, OperationState& state, Value source) { - state.addOperands(source); - } - }; - - /// This dialect independent lowered unary operation has been defined only for - /// testing buffer assignment. 
- class BufferAssignmentTestUnaryLoweredOp - : public Op::Impl> { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.unary_lowered"; - } - static void build(OpBuilder& b, OperationState& state, Value source, - Value target) { - state.addOperands(source); - state.addOperands(target); - } - }; - - /// This dialect independent copy operation has been defined only for testing - /// NonVoidToVoidReturnOpConverter - class BufferAssignmentTestCopyOp - : public Op::Impl> { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.copy"; - } - static void build(OpBuilder& b, OperationState& state, Value from, - Value to) { - state.addOperands(from); - state.addOperands(to); - } - }; - /// A simple converter that legalizes a BufferAssignmentTestUnaryOp to a /// BufferAssignmentTestUnaryLoweredOp and creates buffer allocation for /// the result of the computation. @@ -151,8 +157,12 @@ struct BufferAssignmentPreparationTestPass } }; }; + } // namespace +static mlir::DialectRegistration + buffer_assignment_test_ops; + /// This pass tests helper methods such as computeAllocPosition, /// FunctionAndBlockSignatureConverter, NonVoidToVoidReturnOpConverter /// conversion patterns. Furthermore, it checks buffer-assignment pass that diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index c7be2c55de7..422695c374b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import itertools -import os import numpy as np @@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase): if __name__ == "__main__": - # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems - os.environ[ - "XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get( - "XLA_FLAGS", "") googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index d0e928a5ce6..85bf89c4f9e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -347,17 +347,15 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) - # TODO(b/130689556): Turn this on for CPU when we start honoring NaNs. 
- if self.device != "XLA_CPU": - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], - [19, -19, 22, -22]], - dtype=dtype), - expected=np.array( - [[0.76159418, 0.96402758, 0.99505478, 0.99932933], - [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], + [19, -19, 22, -22]], + dtype=dtype), + expected=np.array( + [[0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], + dtype=dtype)) self._assertOpOutputMatchesExpected( nn_ops.log_softmax, diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index d01f094dc2e..976ff91f6ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -136,8 +136,11 @@ class TensorListReserveOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); OP_REQUIRES( ctx, num_elements >= 0, - errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Set the number of elements.")); + errors::InvalidArgument( + "XLA compilation requires a fixed tensor list size. Set the number " + "of elements. This could also happen if you're using a TensorArray " + "in a while loop that does not have its maximum_iteration set, you " + "can fix this by setting maximum_iteration to a suitable value.")); // If element shape is compile time constant and it's not "unknown rank" // shape (-1), create an initialized TensorList. Otherwise create an @@ -197,10 +200,13 @@ class EmptyTensorListOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); - OP_REQUIRES( - ctx, max_num_elements >= 0, - errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Set the max number of elements.")); + OP_REQUIRES(ctx, max_num_elements >= 0, + errors::InvalidArgument( + "XLA compilation requires a fixed tensor list size. Set " + "the max number of elements. This could also happen if " + "you're using a TensorArray in a while loop that does not " + "have its maximum_iteration set, you can fix this by " + "setting maximum_iteration to a suitable value.")); if (dtype_ != DT_VARIANT) { // We are creating a non-nested TensorList. 
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index f1ac1fef451..5d7bd26b01e 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -63,10 +63,6 @@ class ExecutionInput { explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {} explicit ExecutionInput(ShapeTree buffers) : buffers_(std::move(buffers)) {} - ExecutionInput(ShapeTree buffers, - std::vector owner_held_indices) - : buffers_(std::move(buffers)), - unowned_indices_(std::move(owner_held_indices)) {} ExecutionInput(ExecutionInput&&) = default; ~ExecutionInput() { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 3c4e9f7c1e6..a3056b1ddad 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation // given a Span of its input elements. - using NestedComputer = std::function( + using NestedComputer = std::function>( const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, @@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view) override { - // TODO(b/118332391): Supported variadic return values. - auto result = compute_nested_(callee, parameters); - if (!result.ok()) { - return result.status(); - } - return std::vector{result.ValueOrDie()}; + return compute_nested_(callee, parameters); } llvm::Value* EmitThreadId() override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 744cd7b56bf..aa8a6215cc7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status IrEmitter::HandleReduce(HloInstruction* instr) { - const HloReduceInstruction* reduce = Cast(instr); - const Shape& out_shape = reduce->shape(); - bool returns_tuple = !out_shape.IsArray(); - int accumulators_count = 1; - if (returns_tuple) { - CHECK(out_shape.IsTuple()); - accumulators_count = out_shape.tuple_shapes_size(); - } - - auto arg = reduce->operand(0); - absl::Span dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - return EmitTargetElementLoop( - *reduce, - [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - std::vector accumulator_addrs; - std::vector accumulator_types; - - // Initialize accumulators with initial values. - for (int i = 0; i < accumulators_count; i++) { - auto init_value = reduce->init_values()[i]; - const Shape& element_shape = - returns_tuple ? out_shape.tuple_shapes(i) : out_shape; - PrimitiveType accumulator_type = element_shape.element_type(); - llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); - llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type); - Store(Load(GetBasePointer(*init_value)), accumulator_addr); - accumulator_addrs.push_back(accumulator_addr); - accumulator_types.push_back(accumulator_llvm_type); - } - - // The enclosing loops go over all the target elements. 
Now we have to - // compute the actual target element. For this, we build a new loop nest - // to iterate over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction - // Value*s are placed for each dimension in dimensions, and all the rest - // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - // Build a full index for the input argument, using reduced_dims_index - // as the base. In reduced_dims_index only the reduction dimensions are - // filled in. We fill in the rest of the dimensions with induction - // Value*s taken from 'index' which iterates over the target array. - // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = *it++; - } - } - CHECK(index.end() == it); - - // Apply the reduction function to the loaded value. - llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), - b_.getInt64Ty()); - std::vector reduction_operands(accumulator_addrs.begin(), - accumulator_addrs.end()); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* input_address = - GetIrArray(*reduce->operand(i), *reduce) - .EmitArrayElementAddress(input_index, &b_); - reduction_operands.push_back(input_address); - } - - llvm::Value* ret_argument; - if (!returns_tuple) { - CHECK_EQ(accumulator_addrs.size(), 1); - ret_argument = accumulator_addrs[0]; - } else { - const Shape& return_shape = function->root_instruction()->shape(); - - llvm::Type* return_value_buffer_type = - llvm_ir::ShapeToIrType(return_shape, module_); - ret_argument = Alloca(return_value_buffer_type); - llvm_ir::IrArray tuple_array(ret_argument, return_shape); - EmitTuple(tuple_array, accumulator_addrs, &b_); - } - - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *function, reduction_operands, ret_argument)); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - - if (!returns_tuple) { - CHECK_EQ(accumulator_addrs.size(), 1); - return Load(accumulator_addrs[0]); - } else { - // Emit a struct for the LoopEmitter dealing with multi-output - // fusion. - llvm::Value* returned_structure = llvm::UndefValue::get( - llvm::StructType::get(b_.getContext(), accumulator_types)); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* accumulator_value = Load(accumulator_addrs[i]); - returned_structure = - b_.CreateInsertValue(returned_structure, accumulator_value, i); - } - return returned_structure; - } - }); -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { // kFusion for library calls should be handled by // IrEmitterUnnested::HandleFusion. 
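Note on the removal above: HandleReduce special-cased reduce in the GPU IrEmitter, and the ComputeNestedElement change in the next hunk lets nested computations return one value per tuple leaf, which is what tuple-shaped ("variadic") reduce needs. For readers unfamiliar with that HLO feature, a plain-C++ sketch of its semantics (illustrative only, not XLA code): one reducer updates several accumulators in lock step, so the result is a tuple of values rather than a single value.

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

// One pass, two accumulators: the analogue of an HLO reduce whose to_apply
// computation returns a (sum, max) tuple.
std::pair<float, float> SumAndMax(const std::vector<float>& xs) {
  float sum = 0.0f;
  float mx = -std::numeric_limits<float>::infinity();
  for (float x : xs) {
    sum += x;              // accumulator 0
    mx = std::max(mx, x);  // accumulator 1
  }
  return {sum, mx};
}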
@@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -StatusOr IrEmitter::ComputeNestedElement( +StatusOr> IrEmitter::ComputeNestedElement( const HloComputation& computation, absl::Span parameter_elements) { + const Shape& return_shape = computation.root_instruction()->shape(); llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType( - computation.root_instruction()->shape().element_type(), module_), - "return_buffer", &b_); + llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); Store(parameter_element, parameter_buffers.back()); } + + std::vector allocas_for_returned_scalars; + if (!return_shape.IsTuple()) { + allocas_for_returned_scalars.push_back(return_buffer); + } else { + allocas_for_returned_scalars = + llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_); + llvm_ir::IrArray tuple_array(return_buffer, return_shape); + + EmitTuple(tuple_array, allocas_for_returned_scalars, &b_); + } + TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return Load(return_buffer); + + std::vector returned_scalars; + returned_scalars.reserve(allocas_for_returned_scalars.size()); + for (llvm::Value* addr : allocas_for_returned_scalars) { + returned_scalars.push_back(Load(addr)); + } + return returned_scalars; } std::vector IrEmitter::ConstructIrArrayForOutputs( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e0fe454dcfe..93712961ea2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleRecv(HloInstruction* recv) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; @@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, const llvm_ir::IrArray::Index& compare_keys_index, const llvm_ir::IrArray& keys_array); - StatusOr ComputeNestedElement( + StatusOr> ComputeNestedElement( const HloComputation& computation, absl::Span parameter_elements); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 3930898d665..ad21efa13c9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,12 +312,13 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, absl::string_view label, - const DebugOptions& debug_options, bool show_backend_config, + const DebugOptions& debug_options, + HloRenderOptions hlo_render_options, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label), debug_options_(debug_options), - show_backend_config_(show_backend_config), + hlo_render_options_(hlo_render_options), profile_(profile), 
filter_(std::move(filter)) {} @@ -384,7 +385,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_backend_config_; + const HloRenderOptions hlo_render_options_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -565,7 +566,8 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { if (subcomp->IsFusionComputation()) { const HloInstruction* fusion = subcomp->FusionInstruction(); - if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) { + if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) || + !hlo_render_options_.show_fusion_subcomputations) { return false; } } @@ -1133,7 +1135,8 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { + if (!hlo_render_options_.show_backend_config || + instr->raw_backend_config_string().empty()) { return ""; } @@ -1604,14 +1607,14 @@ StatusOr RenderGraph(const HloComputation& computation, const DebugOptions& debug_options, RenderedGraphFormat format, const HloExecutionProfile* hlo_execution_profile, - bool show_backend_config) { + HloRenderOptions hlo_render_options) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return Unavailable("Can't render as URL; no URL renderer was registered."); } string rendered_dot = - HloDotDumper(&computation, label, debug_options, show_backend_config, + HloDotDumper(&computation, label, debug_options, hlo_render_options, hlo_execution_profile, NodeFilter()) .Dump(); return WrapDotInFormat(rendered_dot, format); @@ -1619,7 +1622,7 @@ StatusOr RenderGraph(const HloComputation& computation, StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, - bool show_backend_config, + HloRenderOptions hlo_render_options, const absl::flat_hash_set& boundary) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { @@ -1632,7 +1635,7 @@ StatusOr RenderNeighborhoodAround( string rendered_dot = HloDotDumper(node.parent(), label, node.GetModule()->config().debug_options(), - show_backend_config, /*profile=*/nullptr, + hlo_render_options, /*profile=*/nullptr, MakeNodeRadiusAroundFilter(&node, radius, boundary)) .Dump(); return WrapDotInFormat(rendered_dot, format); @@ -1641,7 +1644,7 @@ StatusOr RenderNeighborhoodAround( StatusOr RenderAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, int64 max_nodes, RenderedGraphFormat format, - bool show_backend_config) { + HloRenderOptions hlo_render_options) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return FailedPrecondition( @@ -1663,7 +1666,7 @@ StatusOr RenderAllPathsFromTo(const HloInstruction& from, "NODES***

"); } string rendered_dot = - HloDotDumper(from.parent(), label, debug_options, show_backend_config, + HloDotDumper(from.parent(), label, debug_options, hlo_render_options, /*profile=*/nullptr, filter) .Dump(); return WrapDotInFormat(rendered_dot, format); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 324ac67a6dd..528de77e4e6 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -50,6 +50,14 @@ enum class RenderedGraphFormat { kUrl, }; +struct HloRenderOptions { + // Include the backend config string in the rendered graph. + bool show_backend_config = false; + + // Include the fusion subcomputations in the rendered graph. + bool show_fusion_subcomputations = true; +}; + // Renders an HLO module as a human-readable visual graph. // // Note that this only works well for relatively small graphs (no more than a @@ -61,7 +69,7 @@ StatusOr RenderGraph( const HloComputation& computation, absl::string_view label, const DebugOptions& debug_options, RenderedGraphFormat format, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_backend_config = false); + HloRenderOptions hlo_render_options = {}); // Like RenderGraph, but renders only nodes "near" the given node in the graph. // @@ -73,7 +81,7 @@ StatusOr RenderGraph( // will be omitted even if they are within the radius. StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, - bool show_backend_config = false, + HloRenderOptions hlo_render_options = {}, const absl::flat_hash_set& boundary = {}); // Renders nodes on any of the paths from `from` to `to`. If there are more @@ -82,7 +90,7 @@ StatusOr RenderNeighborhoodAround( StatusOr RenderAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, int64 max_nodes, RenderedGraphFormat format, - bool show_backend_config = false); + HloRenderOptions hlo_render_options = {}); // Registers a function which implements RenderedGraphFormat::kUrl. // diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index b7a67b4e66e..995b0ece7cd 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -137,9 +137,8 @@ class ShapedBuffer { std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); -// ShapedBuffer derived class which allocates all internal buffers on -// construction and deallocates the memory when the object is -// destructed. +// ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on +// destruction. This class represents an owning wrapper around `ShapedBuffer`. // // TODO(timshen): Remove inheritance between ScopedShapedBuffer and // ShapedBuffer. 
There should never be a need to consider a ScopedShapedBuffer diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index 8eee452328e..068442ad5c7 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -81,11 +81,10 @@ void SpmdLogger::RegisterLogEntry(HloInstruction* hlo, string report = hlo->ToString(); int64 max_value = -1; for (HloInstruction* inst : group) { - if (inst->shape().IsTuple()) { + if (!inst->shape().IsArray()) { continue; } - max_value = - std::max(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4)); + max_value = std::max(max_value, ShapeSizeInBytes(inst->shape())); absl::StrAppend(&report, " * ", inst->ToString(), "\n"); } entries_.push_back(std::make_pair(max_value, report)); @@ -149,14 +148,14 @@ template const auto add_report = [&](std::vector* insts) { std::sort(insts->begin(), insts->end(), [](const HloInstruction* inst0, const HloInstruction* inst1) { - return ShapeUtil::ByteSizeOf(inst0->shape()) > - ShapeUtil::ByteSizeOf(inst1->shape()); + return ShapeSizeInBytes(inst0->shape()) > + ShapeSizeInBytes(inst1->shape()); }); for (int64 i = 0; i < std::min(report_instruction_count, insts->size()); ++i) { absl::StrAppend(&report, " ", tensorflow::strings::HumanReadableNumBytes( - ShapeUtil::ByteSizeOf((*insts)[i]->shape())), + ShapeSizeInBytes((*insts)[i]->shape())), " : ", (*insts)[i]->ToString(), "\n"); } }; @@ -1180,8 +1179,8 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( operand, scatter_dims_to_operand_dims, slice_size, num_partitions_) && - ShapeUtil::ByteSizeOf(updates.base_shape()) < - ShapeUtil::ByteSizeOf(scatter->shape())) { + ShapeSizeInBytes(updates.base_shape()) < + ShapeSizeInBytes(scatter->shape())) { // Operand is sharded on trivial slice dims (update slice size 1). We can // adjust the indices on each partition by subtracting the offsets. Then // we execute a scatter on full updated indices, and out-of-bound accesses @@ -1968,8 +1967,8 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( operand, start_index_map, gather->gather_slice_sizes(), num_partitions_) && - ShapeUtil::ByteSizeOf(gather->shape()) < - ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) { + ShapeSizeInBytes(gather->shape()) < + ShapeSizeInBytes(gather->operand(0)->shape())) { indices = indices.Reshard(HloSharding::Replicate()); // Now the operand is partitioned in trivial slice dimensions, and the // indices are replicated. 
We execute a gather on partitioned operand, @@ -2762,8 +2761,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( auto zero = b_.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()))); - if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < - ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { return DefaultAction(hlo); } @@ -3005,8 +3003,8 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { }; auto zero = b_.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()))); - if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < - ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (ShapeSizeInBytes(lhs.base_shape()) < + ShapeSizeInBytes(rhs.base_shape())) { if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { return DefaultAction(hlo); } @@ -3731,7 +3729,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper( }; if (output_lhs_non_contracting_partitions == num_partitions_ && output_sharding_transposed_to_match_lhs == lhs_sharding && - ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >= + ShapeSizeInBytes(hlo->operand(1)->shape()) >= options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (rhs_contracting_partitions == num_partitions_) { return emit_windowed_dot_general(0, 1, true, false); @@ -3745,7 +3743,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper( } if (output_rhs_non_contracting_partitions == num_partitions_ && output_sharding_transposed_to_match_rhs == rhs_sharding && - ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >= + ShapeSizeInBytes(hlo->operand(0)->shape()) >= options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (lhs_contracting_partitions == num_partitions_) { return emit_windowed_dot_general(1, 0, true, false); @@ -3775,8 +3773,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper( LiteralUtil::Zero(hlo->shape().element_type()))); // Pad both sides with zero, since NaN at one side cannot be masked by zero // on the other side. 
- if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < - ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (ShapeSizeInBytes(lhs.base_shape()) < + ShapeSizeInBytes(rhs.base_shape())) { lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); rhs = rhs.PadWithValue(zero); @@ -4607,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, xpose_permutation[i] = i + tiled_dims.size() - split_dims_added; } else { xpose_permutation[i] = split_dims_added; + xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added; split_dims_added++; - xpose_permutation[i + 1] = i + tiled_dims.size(); i++; } } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 55d7dc43785..e766695385b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -649,6 +649,43 @@ ENTRY entry { op::ReduceWindow(masked, op::Constant()))); } +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1), + window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum, + sharding={devices=[5,1]0,1,2,3,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/5)); + VLOG(1) << module->ToString(); + auto halo0 = AllOf(op::Shape("f32[1,2]"), + op::CollectivePermute(op::Slice(op::Parameter(0)))); + auto halo1 = + AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0))); + auto pre_mask = + AllOf(op::Shape("f32[4,2]"), + op::Slice(AllOf(op::Shape("f32[5,2]"), + op::Concatenate(halo0, halo1, op::Parameter(0))))); + auto masked = + op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())), + op::Broadcast(op::Constant())), + pre_mask, op::Broadcast(op::Constant())); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { const char* const hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 207f854cd9f..8db2ca84a05 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include + #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -23,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -104,6 +107,11 @@ Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { return sharding.TileShape(shape); } +int64 ShapeSizeInBytes(const Shape& shape) { + return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) * + ShapeUtil::ElementsIn(shape); +} + Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, const HloSharding& sharding, int64 partition_id) { @@ -402,33 +410,30 @@ absl::optional ExchangeHalo( std::vector concat_pieces; int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); - if (max_left_halo_size > input_shard_size) { - VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; - return absl::nullopt; - } - if (max_left_halo_size > 0) { + for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0; + --i) { std::vector> source_target_pairs; target.tile_assignment().Each( [&](absl::Span indices, int64 device) { - if (indices[dim] > 0) { + if (indices[dim] > i) { std::vector source_indices(indices.begin(), indices.end()); - source_indices[dim] -= 1; + source_indices[dim] -= i + 1; source_target_pairs.emplace_back( target.tile_assignment()(source_indices), device); } }); + int64 halo_size = + std::min(max_left_halo_size - input_shard_size * i, input_shard_size); auto halo_shape = hlo->shape(); auto source_halo_slice = hlo; - if (max_left_halo_size != hlo->shape().dimensions(dim)) { - halo_shape.set_dimensions(dim, max_left_halo_size); + if (halo_size != hlo->shape().dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); std::vector halo_start_indices(halo_shape.rank(), 0); - halo_start_indices[dim] = - hlo->shape().dimensions(dim) - max_left_halo_size; + halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size; std::vector halo_slice_strides(halo_shape.rank(), 1); - - source_halo_slice = b->AddInstruction( - hlo->CreateSlice(halo_shape, hlo, halo_start_indices, - hlo->shape().dimensions(), halo_slice_strides)); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(), + halo_slice_strides)); } auto left_halo = collective_ops_creator.create_cross_partition_collective_permute( @@ -441,29 +446,30 @@ absl::optional ExchangeHalo( // Right halo. 
int64 max_right_halo_size = right_halo_size_function.MaxInRange(0, shard_count - 1); - if (max_right_halo_size > input_shard_size) { - VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; - return absl::nullopt; - } - if (max_right_halo_size > 0) { + for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size); + ++i) { std::vector> source_target_pairs; target.tile_assignment().Each( [&](absl::Span indices, int64 device) { - if (indices[dim] > 0) { + if (indices[dim] > i) { std::vector target_indices(indices.begin(), indices.end()); - target_indices[dim] -= 1; + target_indices[dim] -= i + 1; source_target_pairs.emplace_back( device, target.tile_assignment()(target_indices)); } }); + int64 halo_size = + std::min(max_right_halo_size - input_shard_size * i, input_shard_size); auto halo_shape = hlo->shape(); - halo_shape.set_dimensions(dim, max_right_halo_size); - std::vector halo_start_indices(halo_shape.rank(), 0); - std::vector halo_slice_strides(halo_shape.rank(), 1); - - auto source_halo_slice = b->AddInstruction( - hlo->CreateSlice(halo_shape, hlo, halo_start_indices, - halo_shape.dimensions(), halo_slice_strides)); + HloInstruction* source_halo_slice = hlo; + if (halo_size != halo_shape.dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + std::vector halo_slice_strides(halo_shape.rank(), 1); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, halo_shape.dimensions(), + halo_slice_strides)); + } auto right_halo = collective_ops_creator.create_cross_partition_collective_permute( b, source_halo_slice, source_target_pairs, (*next_channel_id)++); diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index f96b23d7073..440f0e78112 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -57,6 +57,10 @@ bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding); // target sharding. Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding); +// Similar to ShapeUtil::ByteSizeOf(), but does not check it has dense layout +// since this can be before layout assignment. +int64 ShapeSizeInBytes(const Shape& shape); + // Returns the shard shape for a partition without padding due to uneven // sharding. Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index 4f8a6b43314..b6c62beff74 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -112,8 +112,7 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100; using absl::EqualsIgnoreCase; -// A global control for whether backend configuration display is enabled. -bool show_backend_config = true; +HloRenderOptions hlo_render_options; HloInstruction* FindInstruction(const HloModule& module, string node_name) { if (absl::StartsWith(node_name, "%")) { @@ -160,6 +159,8 @@ void DoHelpCommand() { Renders all nodes in . backend_config [on|off] Controls whether backend operation configuration information is printed. + show_fusion_subcomputations [on|off] + Controls whether fusion subcomputations are shown. 
list [name|op_name|op_type] Lists all instructions whose name, metadata op_name, or metadata op_type contains as a substring. @@ -182,15 +183,32 @@ void DoHelpCommand() { // Turn metadata-printing on or off. void DoBackendConfigCommand(const std::vector& tokens) { if (tokens.size() == 2 && tokens[1] == "on") { - show_backend_config = true; + hlo_render_options.show_backend_config = true; } else if (tokens.size() == 2 && tokens[1] == "off") { - show_backend_config = false; + hlo_render_options.show_backend_config = false; } else if (tokens.size() != 1) { std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)" << std::endl; } std::cout << "Backend configuration display " - << (show_backend_config ? "ON" : "OFF") << std::endl; + << (hlo_render_options.show_backend_config ? "ON" : "OFF") + << std::endl; +} + +// Turn fusion computation display on or off. +void DoShowFusionSubcomputationsCommand(const std::vector& tokens) { + if (tokens.size() == 2 && tokens[1] == "on") { + hlo_render_options.show_fusion_subcomputations = true; + } else if (tokens.size() == 2 && tokens[1] == "off") { + hlo_render_options.show_fusion_subcomputations = false; + } else if (tokens.size() != 1) { + std::cerr << "(Illegal show_fusion_subcomputations value. Use either " + "'on' or 'off'.)" + << std::endl; + } + std::cout << "Fusion subcomputations display " + << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF") + << std::endl; } // List all computations in the module. @@ -373,7 +391,7 @@ void DoExtractCommand(const HloModule& module, auto extracted_module = ExtractModule(instr, height); std::cout << extracted_module->ToString( HloPrintOptions::ShortParsable().set_print_backend_config( - show_backend_config)) + hlo_render_options.show_backend_config)) << std::endl; } @@ -517,7 +535,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module, } RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { return RenderAllPathsFromTo(*from, *to, max_nodes, format, - /*show_backend_config=*/show_backend_config); + hlo_render_options); }); } @@ -582,15 +600,13 @@ void DoPlotCommand(const Options& opts, const HloModule& module, RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { return RenderGraph(*comp, /*label=*/"", comp->parent()->config().debug_options(), format, - /*hlo_execution_profile=*/nullptr, - /*show_backend_config=*/show_backend_config); + /*hlo_execution_profile=*/nullptr, hlo_render_options); }); } else { RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { - return RenderNeighborhoodAround( - *instr, graph_width, format, - /*show_backend_config=*/show_backend_config, - /*boundary=*/boundary); + return RenderNeighborhoodAround(*instr, graph_width, format, + hlo_render_options, + /*boundary=*/boundary); }); } } @@ -617,6 +633,8 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { DoHelpCommand(); } else if (tokens[0] == "backend_config") { DoBackendConfigCommand(tokens); + } else if (tokens[0] == "show_fusion_subcomputations") { + DoShowFusionSubcomputationsCommand(tokens); } else if (tokens[0] == "list") { if (tokens.size() > 1 && tokens[1] == "computations") { DoListComputationsCommand(module, tokens); diff --git a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt index 09eff6177b1..ae5942b3617 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt @@ -84,6 +84,13 
@@ END name: "Tout" description: <
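Note on the hlo_graph_dumper / interactive_graphviz changes earlier in this patch: the bool show_backend_config parameter is folded into an HloRenderOptions struct, which also gains show_fusion_subcomputations. A hedged usage sketch based only on the declarations added above; the RenderedGraphFormat::kDot value and the surrounding computation/debug-options objects are assumptions, not shown in this diff:

#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"

// Renders a computation with backend configs shown and fusion bodies hidden.
xla::StatusOr<std::string> RenderWithoutFusionBodies(
    const xla::HloComputation& computation,
    const xla::DebugOptions& debug_options) {
  xla::HloRenderOptions options;
  options.show_backend_config = true;           // replaces the old bool flag
  options.show_fusion_subcomputations = false;  // new toggle in this patch
  return xla::RenderGraph(computation, /*label=*/"", debug_options,
                          xla::RenderedGraphFormat::kDot,
                          /*hlo_execution_profile=*/nullptr, options);
}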
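Note on the ExchangeHalo rewrite in spmd_partitioner_util.cc above: a halo larger than one shard is now gathered in multiple collective-permute steps, where step i takes min(max_halo - shard_size * i, shard_size) rows from the shard i+1 positions away. A self-contained sketch of that arithmetic, using the numbers from the new ReduceWindowTiledOneSideHaloBeyondNeighbor test (9 rows over 5 partitions, so shard size 2 and a left halo of 3); this is illustrative code, not the partitioner itself:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Per-step halo sizes for one side, mirroring the loop bounds in the patch:
// steps run from the farthest source shard toward the immediate neighbor.
std::vector<int64_t> HaloStepSizes(int64_t max_halo, int64_t shard_size) {
  std::vector<int64_t> sizes;
  int64_t steps = (max_halo + shard_size - 1) / shard_size;  // CeilOfRatio
  for (int64_t i = steps - 1; i >= 0; --i) {
    sizes.push_back(std::min(max_halo - shard_size * i, shard_size));
  }
  return sizes;
}

int main() {
  // Left halo of 3 with shard size 2: a 1-row slice from two shards away,
  // then the full 2-row neighbor shard, matching the f32[1,2] and f32[2,2]
  // collective-permutes expected by the test.
  for (int64_t s : HaloStepSizes(/*max_halo=*/3, /*shard_size=*/2)) {
    std::printf("%lld\n", static_cast<long long>(s));
  }
  return 0;
}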