Add xla.set_bound op.
For cases where we cannot infer the bound of a value, the compilation would fail. This gives user an escape patch. PiperOrigin-RevId: 329626655 Change-Id: Ib5d71054088692697eaf5f2b21c0c5d1a097f1eb
This commit is contained in:
parent
dfaa328f06
commit
9c703cc790
@ -2079,6 +2079,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"XlaSelectAndScatter",
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
"XlaSetBound",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSpmdFullToShardShape",
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA-specific Shape Ops.
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
@ -24,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
|
||||
|
||||
class XlaSetBoundOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaSetBoundOp(OpKernelConstruction* context)
|
||||
: XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape("input");
|
||||
const TensorShape bound_shape = ctx->InputShape("bound");
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
ctx->InputType("bound") == DT_INT32 &&
|
||||
ctx->InputType("input") == DT_INT32,
|
||||
errors::InvalidArgument(
|
||||
"XlaSetBound can only set bound for int32 scalar value: got",
|
||||
input_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, input_shape.dims() == 0,
|
||||
errors::InvalidArgument("XlaSetBound should only be used to set a "
|
||||
"bound to the an int32 scalar value: got",
|
||||
input_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, bound_shape.dims() == 0,
|
||||
errors::InvalidArgument("XlaSetBound should only be used to set a "
|
||||
"bound to the an int32 scalar value: got",
|
||||
bound_shape.DebugString()));
|
||||
int64 bound;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
|
||||
|
||||
xla::XlaOp result = xla::CustomCall(
|
||||
ctx->builder(), "SetBound", {ctx->Input("input")},
|
||||
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
|
||||
XlaSetBoundOp);
|
||||
|
||||
class ShapeNOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
|
@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
|
||||
precision_config: a serialized xla::PrecisionConfig proto.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaSetBound")
|
||||
.Input("input: int32")
|
||||
.Input("bound: int32")
|
||||
.Output("output: int32")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(
|
||||
R"doc(Set a bound for the given input value as a hint to Xla compiler,
|
||||
returns the same value.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaDynamicSlice")
|
||||
.Input("input: T")
|
||||
.Input("start_indices: Tindices")
|
||||
|
@ -387,6 +387,14 @@ def reduce_window(operand,
|
||||
|
||||
replica_id = gen_xla_ops.xla_replica_id
|
||||
|
||||
# Set a static bound for the given input value as a hint to Xla compiler,
|
||||
# returns the same value.
|
||||
# Usage:
|
||||
# def f(t, p):
|
||||
# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
|
||||
# return t[:p] # xla knows the bound of the slice is 3.
|
||||
set_bound = gen_xla_ops.xla_set_bound
|
||||
|
||||
|
||||
def reshape(x, new_sizes, dimensions=None, name=None):
|
||||
if dimensions is not None:
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
@ -42,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -117,6 +119,15 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
|
||||
}
|
||||
return reducer;
|
||||
}
|
||||
|
||||
bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
|
||||
HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
|
||||
if (opcode == HloOpcode::kCustomCall &&
|
||||
instr_proto->custom_call_target() == "SetBound") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
@ -293,7 +304,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
||||
// GetDimensionSize is always considered constant in XLA -- If a dynamic
|
||||
// dimension is presented, -1 is returned.
|
||||
break;
|
||||
|
||||
// Non functional ops.
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kAllReduce:
|
||||
@ -306,6 +316,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
||||
// cannot be constant. We cannot set is_functional=false in other similar
|
||||
// cases since we're already relying on IsConstant to return true.
|
||||
case HloOpcode::kCustomCall:
|
||||
if (instr.custom_call_target() == "SetBound") {
|
||||
// Set bound is considered constant -- the bound is used as the value.
|
||||
break;
|
||||
}
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
case HloOpcode::kWhile:
|
||||
// TODO(b/32495713): We aren't checking the condition and body
|
||||
// computations themselves.
|
||||
@ -3086,6 +3101,15 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
case HloOpcode::kConstant:
|
||||
SetInstructionAsConstant(new_instr, id, new_shape, false);
|
||||
break;
|
||||
case HloOpcode::kCustomCall:
|
||||
if (instr_proto->custom_call_target() == "SetBound") {
|
||||
SetInstructionAsConstant(new_instr, id, new_shape, true);
|
||||
break;
|
||||
} else {
|
||||
return InvalidArgument(
|
||||
"Dynamic inferencing on custom call %s is not supported",
|
||||
instr_proto->DebugString());
|
||||
}
|
||||
case HloOpcode::kParameter:
|
||||
SetInstructionAsConstant(new_instr, id, new_shape, true);
|
||||
break;
|
||||
@ -3149,7 +3173,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
||||
StringToHloOpcode(instr_proto->opcode()));
|
||||
if (next_operand >= instr_proto->operand_ids_size() ||
|
||||
opcode == HloOpcode::kGetDimensionSize) {
|
||||
opcode == HloOpcode::kGetDimensionSize ||
|
||||
InstrIsSetBound(instr_proto)) {
|
||||
// No more operands to process, process self.
|
||||
int64 new_id = ++global_id;
|
||||
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
|
||||
@ -3235,26 +3260,33 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
LookUpInstructionByHandle(handle));
|
||||
|
||||
if (instr_proto->opcode() ==
|
||||
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||
// At this point, BuildConstantSubGraph should never encounter a
|
||||
// GetDimensionSize with a dynamic dimension. IsConstant check would have
|
||||
// failed at the beginning of this function.
|
||||
//
|
||||
// Replace GetDimensionSize with a Constant representing the static bound
|
||||
// of the shape.
|
||||
int64 dimension = instr_proto->dimensions(0);
|
||||
int64 operand_handle = instr_proto->operand_ids(0);
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
|
||||
LookUpInstructionByHandle(operand_handle));
|
||||
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
|
||||
InstrIsSetBound(instr_proto)) {
|
||||
int32 constant_value = -1;
|
||||
if (instr_proto->opcode() ==
|
||||
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||
// At this point, BuildConstantSubGraph should never encounter a
|
||||
// GetDimensionSize with a dynamic dimension. IsConstant check would
|
||||
// have failed at the beginning of this function.
|
||||
//
|
||||
// Replace GetDimensionSize with a Constant representing the static
|
||||
// bound of the shape.
|
||||
int64 dimension = instr_proto->dimensions(0);
|
||||
int64 operand_handle = instr_proto->operand_ids(0);
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
|
||||
LookUpInstructionByHandle(operand_handle));
|
||||
|
||||
int32 constant_dimension_size = -1;
|
||||
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
|
||||
dynamic_dimension_is_minus_one)) {
|
||||
constant_dimension_size =
|
||||
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
||||
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
|
||||
dynamic_dimension_is_minus_one)) {
|
||||
constant_value =
|
||||
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
||||
}
|
||||
} else {
|
||||
TF_RET_CHECK(
|
||||
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
|
||||
}
|
||||
|
||||
Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
|
||||
Literal literal = LiteralUtil::CreateR0(constant_value);
|
||||
|
||||
HloInstructionProto const_instr;
|
||||
*const_instr.mutable_shape() = literal.shape().ToProto();
|
||||
@ -3286,6 +3318,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||
continue;
|
||||
}
|
||||
if (InstrIsSetBound(instr_src)) {
|
||||
continue;
|
||||
}
|
||||
auto* instr = entry.add_instructions();
|
||||
|
||||
*instr = *instr_src;
|
||||
|
@ -179,6 +179,22 @@ StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
|
||||
if (instr->opcode() != HloOpcode::kCustomCall ||
|
||||
instr->custom_call_target() != "SetBound") {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
|
||||
instr->shape(), instr->operand(0)->shape()))
|
||||
<< "instr->shape() " << instr->shape().ToString() << " , "
|
||||
<< "instruction operand shape " << instr->operand(0)->shape();
|
||||
HloInstruction* operand = instr->mutable_operand(0);
|
||||
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
||||
int64 dimension) {
|
||||
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
||||
@ -1370,7 +1386,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_set_bound,
|
||||
ReplaceSetBound(instruction));
|
||||
changed = changed || replaced_set_size;
|
||||
changed = changed || replaced_set_bound;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user