diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index af0a192639c..46d2354b779 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2079,6 +2079,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() { "XlaSelectAndScatter", "XlaSelfAdjointEig", "XlaSend", + "XlaSetBound", "XlaSharding", "XlaSort", "XlaSpmdFullToShardShape", diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 85917af6a65..75faa2eac81 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -24,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); +class XlaSetBoundOp : public XlaOpKernel { + public: + explicit XlaSetBoundOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape bound_shape = ctx->InputShape("bound"); + + OP_REQUIRES( + ctx, + ctx->InputType("bound") == DT_INT32 && + ctx->InputType("input") == DT_INT32, + errors::InvalidArgument( + "XlaSetBound can only set bound for int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, input_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to an int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, bound_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to an int32 scalar value: got", + bound_shape.DebugString())); + int64 bound; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); + + xla::XlaOp result = xla::CustomCall( + ctx->builder(), "SetBound", {ctx->Input("input")}, + ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound)); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"), + XlaSetBoundOp); + class ShapeNOp : public XlaOpKernel { public: explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc 
b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f4b9e9654d2..2f895b17219 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaSetBound") + .Input("input: int32") + .Input("bound: int32") + .Output("output: int32") + .SetShapeFn(shape_inference::UnknownShape) + .Doc( + R"doc(Set a bound for the given input value as a hint to the XLA compiler, + returns the same value. +)doc"); + REGISTER_OP("XlaDynamicSlice") .Input("input: T") .Input("start_indices: Tindices") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 846dafa2570..19104518b71 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -387,6 +387,14 @@ def reduce_window(operand, replica_id = gen_xla_ops.xla_replica_id +# Set a static bound for the given input value as a hint to the XLA compiler, +# returns the same value. +# Usage: +# def f(t, p): +# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3. +# return t[:p] # xla knows the bound of the slice is 3. +set_bound = gen_xla_ops.xla_set_bound + def reshape(x, new_sizes, dimensions=None, name=None): if dimensions is not None: diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 34d78f9d933..3e2a4eb53a7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -42,6 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace xla { @@ -117,6 +119,15 @@ HloComputationProto CreateReduceOr(int64 reducer_id, } return reducer; } + +bool InstrIsSetBound(const HloInstructionProto* instr_proto) { + HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie(); + if (opcode == HloOpcode::kCustomCall && + instr_proto->custom_call_target() == "SetBound") { + return true; + } + return false; +} } // namespace namespace internal { @@ -293,7 +304,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // GetDimensionSize is always considered constant in XLA -- If a dynamic // dimension is presented, -1 is returned. break; - // Non functional ops. case HloOpcode::kRng: case HloOpcode::kAllReduce: @@ -306,6 +316,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. case HloOpcode::kCustomCall: + if (instr.custom_call_target() == "SetBound") { + // Set bound is considered constant -- the bound is used as the value. + break; + } + TF_FALLTHROUGH_INTENDED; case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. 
@@ -3086,6 +3101,15 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { case HloOpcode::kConstant: SetInstructionAsConstant(new_instr, id, new_shape, false); break; + case HloOpcode::kCustomCall: + if (instr_proto->custom_call_target() == "SetBound") { + SetInstructionAsConstant(new_instr, id, new_shape, true); + break; + } else { + return InvalidArgument( + "Dynamic inferencing on custom call %s is not supported", + instr_proto->DebugString()); + } case HloOpcode::kParameter: SetInstructionAsConstant(new_instr, id, new_shape, true); break; @@ -3149,7 +3173,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(instr_proto->opcode())); if (next_operand >= instr_proto->operand_ids_size() || - opcode == HloOpcode::kGetDimensionSize) { + opcode == HloOpcode::kGetDimensionSize || + InstrIsSetBound(instr_proto)) { // No more operands to process, process self. int64 new_id = ++global_id; VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name(); @@ -3235,26 +3260,33 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( LookUpInstructionByHandle(handle)); if (instr_proto->opcode() == - HloOpcodeString(HloOpcode::kGetDimensionSize)) { - // At this point, BuildConstantSubGraph should never encounter a - // GetDimensionSize with a dynamic dimension. IsConstant check would have - // failed at the beginning of this function. - // - // Replace GetDimensionSize with a Constant representing the static bound - // of the shape. 
- int64 dimension = instr_proto->dimensions(0); - int64 operand_handle = instr_proto->operand_ids(0); - TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, - LookUpInstructionByHandle(operand_handle)); + HloOpcodeString(HloOpcode::kGetDimensionSize) || + InstrIsSetBound(instr_proto)) { + int32 constant_value = -1; + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would + // have failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static + // bound of the shape. + int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); - int32 constant_dimension_size = -1; - if (!(operand_proto->shape().is_dynamic_dimension(dimension) && - dynamic_dimension_is_minus_one)) { - constant_dimension_size = - static_cast<int32>(operand_proto->shape().dimensions(dimension)); + if (!(operand_proto->shape().is_dynamic_dimension(dimension) && + dynamic_dimension_is_minus_one)) { + constant_value = + static_cast<int32>(operand_proto->shape().dimensions(dimension)); + } + } else { + TF_RET_CHECK( + absl::SimpleAtoi(instr_proto->backend_config(), &constant_value)); } - Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + Literal literal = LiteralUtil::CreateR0(constant_value); HloInstructionProto const_instr; *const_instr.mutable_shape() = literal.shape().ToProto(); @@ -3286,6 +3318,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { continue; } + if (InstrIsSetBound(instr_src)) { + continue; + } auto* instr = entry.add_instructions(); *instr = *instr_src; diff --git 
a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 9b4d24bbbe9..b4c56113239 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -179,6 +179,22 @@ StatusOr<bool> ReplaceSetSize(HloInstruction* instr) { return true; } +StatusOr<bool> ReplaceSetBound(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kCustomCall || + instr->custom_call_target() != "SetBound") { + return false; + } + + TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( + instr->shape(), instr->operand(0)->shape())) + << "instr->shape() " << instr->shape().ToString() << " , " + << "instruction operand shape " << instr->operand(0)->shape(); + HloInstruction* operand = instr->mutable_operand(0); + + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + return true; +} + bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num, int64 dimension) { if ((inst->opcode() == HloOpcode::kReduceWindow || @@ -1370,7 +1386,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) { for (auto* computation : module->computations()) { for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced_set_bound, + ReplaceSetBound(instruction)); changed = changed || replaced_set_size; + changed = changed || replaced_set_bound; } }