Support unsigned indices for in-place DynamicUpdateSlice.
For unsigned indices, we need to use unsigned comparisons when clamping the start_indices. Also rename the files from ops.* to dynamic_update_slice_util.*.

PiperOrigin-RevId: 205072344
This commit is contained in:
parent
ff791a7fde
commit
a46c9ab441
@ -252,12 +252,12 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
|
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ops",
|
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@llvm//:code_gen",
|
"@llvm//:code_gen",
|
||||||
|
@ -51,10 +51,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
@ -162,6 +162,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
|
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
|
||||||
@ -169,7 +170,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ops",
|
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
|
@ -59,10 +59,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
||||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
@ -164,9 +164,9 @@ cc_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ops",
|
name = "dynamic_update_slice_util",
|
||||||
srcs = ["ops.cc"],
|
srcs = ["dynamic_update_slice_util.cc"],
|
||||||
hdrs = ["ops.h"],
|
hdrs = ["dynamic_update_slice_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":fused_ir_emitter",
|
":fused_ir_emitter",
|
||||||
":ir_array",
|
":ir_array",
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
|
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
|
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
@ -38,8 +38,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
|||||||
// Emits a sequential loop if launch_dimensions is null.
|
// Emits a sequential loop if launch_dimensions is null.
|
||||||
static Status EmitDynamicUpdateSliceInPlaceImpl(
|
static Status EmitDynamicUpdateSliceInPlaceImpl(
|
||||||
const Shape& update_shape, const ElementGenerator& start_indices_generator,
|
const Shape& update_shape, const ElementGenerator& start_indices_generator,
|
||||||
ElementGenerator update_array_generator, const IrArray& output_array,
|
bool is_signed, ElementGenerator update_array_generator,
|
||||||
const gpu::LaunchDimensions* launch_dimensions,
|
const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
|
||||||
tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) {
|
tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) {
|
||||||
const Shape& output_shape = output_array.GetShape();
|
const Shape& output_shape = output_array.GetShape();
|
||||||
|
|
||||||
@ -59,17 +59,20 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
|
|||||||
|
|
||||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||||
// currently respected by all implementations. Change this if we ever decide
|
// currently respected by all implementations. Change this if we ever decide
|
||||||
// to oficially document different behavior.
|
// to officially document different behavior.
|
||||||
llvm::Value* max_bound =
|
llvm::Value* max_bound =
|
||||||
ir_builder->CreateSub(output_dim_size, update_dim_size);
|
ir_builder->CreateSub(output_dim_size, update_dim_size);
|
||||||
llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
|
llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
|
||||||
start_index[i] = ir_builder->CreateSelect(
|
start_index[i] = ir_builder->CreateSelect(
|
||||||
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]),
|
ir_builder->CreateICmp(
|
||||||
|
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
|
||||||
|
zero, start_index[i]),
|
||||||
zero, start_index[i]);
|
zero, start_index[i]);
|
||||||
|
|
||||||
start_index[i] = ir_builder->CreateSelect(
|
start_index[i] = ir_builder->CreateSelect(
|
||||||
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound,
|
ir_builder->CreateICmp(
|
||||||
start_index[i]),
|
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
|
||||||
|
max_bound, start_index[i]),
|
||||||
max_bound, start_index[i]);
|
max_bound, start_index[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,8 +125,9 @@ Status EmitDynamicUpdateSliceInPlace(
|
|||||||
return update_array.EmitReadArrayElement(index, ir_builder);
|
return update_array.EmitReadArrayElement(index, ir_builder);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
|
||||||
return EmitDynamicUpdateSliceInPlaceImpl(
|
return EmitDynamicUpdateSliceInPlaceImpl(
|
||||||
update_shape, start_indices_generator, update_array_generator,
|
update_shape, start_indices_generator, is_signed, update_array_generator,
|
||||||
output_array, /*launch_dimensions=*/nullptr, name, ir_builder);
|
output_array, /*launch_dimensions=*/nullptr, name, ir_builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,8 +174,9 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
|
|||||||
ElementGenerator start_indices_generator =
|
ElementGenerator start_indices_generator =
|
||||||
fused_emitter.GetGenerator(start_indices);
|
fused_emitter.GetGenerator(start_indices);
|
||||||
|
|
||||||
|
bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
|
||||||
return EmitDynamicUpdateSliceInPlaceImpl(
|
return EmitDynamicUpdateSliceInPlaceImpl(
|
||||||
update_shape, start_indices_generator, update_array_generator,
|
update_shape, start_indices_generator, is_signed, update_array_generator,
|
||||||
fusion_output_array, launch_dimensions, IrName(fusion), ir_builder);
|
fusion_output_array, launch_dimensions, IrName(fusion), ir_builder);
|
||||||
}
|
}
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||||
@ -90,4 +90,4 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace(
|
|||||||
} // namespace llvm_ir
|
} // namespace llvm_ir
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_
|
Loading…
Reference in New Issue
Block a user