Add xla.set_bound op.
For cases where we cannot infer the bound of a value, the compilation would fail. This gives users an escape hatch. PiperOrigin-RevId: 329626655 Change-Id: Ib5d71054088692697eaf5f2b21c0c5d1a097f1eb
This commit is contained in:
parent
dfaa328f06
commit
9c703cc790
@ -2079,6 +2079,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
|||||||
"XlaSelectAndScatter",
|
"XlaSelectAndScatter",
|
||||||
"XlaSelfAdjointEig",
|
"XlaSelfAdjointEig",
|
||||||
"XlaSend",
|
"XlaSend",
|
||||||
|
"XlaSetBound",
|
||||||
"XlaSharding",
|
"XlaSharding",
|
||||||
"XlaSort",
|
"XlaSort",
|
||||||
"XlaSpmdFullToShardShape",
|
"XlaSpmdFullToShardShape",
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
// XLA-specific Shape Ops.
|
// XLA-specific Shape Ops.
|
||||||
|
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
@ -24,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
|
REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
|
||||||
|
|
||||||
|
class XlaSetBoundOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit XlaSetBoundOp(OpKernelConstruction* context)
|
||||||
|
: XlaOpKernel(context) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape input_shape = ctx->InputShape("input");
|
||||||
|
const TensorShape bound_shape = ctx->InputShape("bound");
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
ctx->InputType("bound") == DT_INT32 &&
|
||||||
|
ctx->InputType("input") == DT_INT32,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"XlaSetBound can only set bound for int32 scalar value: got",
|
||||||
|
input_shape.DebugString()));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, input_shape.dims() == 0,
|
||||||
|
errors::InvalidArgument("XlaSetBound should only be used to set a "
|
||||||
|
"bound to the an int32 scalar value: got",
|
||||||
|
input_shape.DebugString()));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, bound_shape.dims() == 0,
|
||||||
|
errors::InvalidArgument("XlaSetBound should only be used to set a "
|
||||||
|
"bound to the an int32 scalar value: got",
|
||||||
|
bound_shape.DebugString()));
|
||||||
|
int64 bound;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
|
||||||
|
|
||||||
|
xla::XlaOp result = xla::CustomCall(
|
||||||
|
ctx->builder(), "SetBound", {ctx->Input("input")},
|
||||||
|
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
|
||||||
|
ctx->SetOutput(0, result);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Registers XlaSetBoundOp as the XLA kernel for "XlaSetBound". "bound" is
// declared a compile-time constant input; XlaSetBoundOp::Compile reads it
// with ConstantInputAsIntScalar.
REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
|
||||||
|
XlaSetBoundOp);
|
||||||
|
|
||||||
class ShapeNOp : public XlaOpKernel {
|
class ShapeNOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
|
@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
|
|||||||
precision_config: a serialized xla::PrecisionConfig proto.
|
precision_config: a serialized xla::PrecisionConfig proto.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// Op definition for XlaSetBound: two int32 inputs (`input`, `bound`) and one
// int32 output. Shape inference reports an unknown output shape.
REGISTER_OP("XlaSetBound")
|
||||||
|
.Input("input: int32")
|
||||||
|
.Input("bound: int32")
|
||||||
|
.Output("output: int32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
|
.Doc(
|
||||||
|
R"doc(Set a bound for the given input value as a hint to Xla compiler,
|
||||||
|
returns the same value.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("XlaDynamicSlice")
|
REGISTER_OP("XlaDynamicSlice")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Input("start_indices: Tindices")
|
.Input("start_indices: Tindices")
|
||||||
|
@ -387,6 +387,14 @@ def reduce_window(operand,
|
|||||||
|
|
||||||
replica_id = gen_xla_ops.xla_replica_id
|
replica_id = gen_xla_ops.xla_replica_id
|
||||||
|
|
||||||
|
# Sets a static bound on the given input value as a hint to the XLA compiler
|
||||||
|
# and returns the same value.
|
||||||
|
# Usage:
|
||||||
|
# def f(t, p):
|
||||||
|
# p = xla.set_bound(p, 3) # Tells XLA the constraint that p <= 3.
|
||||||
|
# return t[:p] # XLA knows the bound of the slice is 3.
|
||||||
|
set_bound = gen_xla_ops.xla_set_bound
|
||||||
|
|
||||||
|
|
||||||
def reshape(x, new_sizes, dimensions=None, name=None):
|
def reshape(x, new_sizes, dimensions=None, name=None):
|
||||||
if dimensions is not None:
|
if dimensions is not None:
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
@ -42,6 +43,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -117,6 +119,15 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
|
|||||||
}
|
}
|
||||||
return reducer;
|
return reducer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
|
||||||
|
HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
|
||||||
|
if (opcode == HloOpcode::kCustomCall &&
|
||||||
|
instr_proto->custom_call_target() == "SetBound") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -293,7 +304,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
|||||||
// GetDimensionSize is always considered constant in XLA -- If a dynamic
|
// GetDimensionSize is always considered constant in XLA -- If a dynamic
|
||||||
// dimension is presented, -1 is returned.
|
// dimension is presented, -1 is returned.
|
||||||
break;
|
break;
|
||||||
|
|
||||||
// Non functional ops.
|
// Non functional ops.
|
||||||
case HloOpcode::kRng:
|
case HloOpcode::kRng:
|
||||||
case HloOpcode::kAllReduce:
|
case HloOpcode::kAllReduce:
|
||||||
@ -306,6 +316,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
|||||||
// cannot be constant. We cannot set is_functional=false in other similar
|
// cannot be constant. We cannot set is_functional=false in other similar
|
||||||
// cases since we're already relying on IsConstant to return true.
|
// cases since we're already relying on IsConstant to return true.
|
||||||
case HloOpcode::kCustomCall:
|
case HloOpcode::kCustomCall:
|
||||||
|
if (instr.custom_call_target() == "SetBound") {
|
||||||
|
// Set bound is considered constant -- the bound is used as the value.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
TF_FALLTHROUGH_INTENDED;
|
||||||
case HloOpcode::kWhile:
|
case HloOpcode::kWhile:
|
||||||
// TODO(b/32495713): We aren't checking the condition and body
|
// TODO(b/32495713): We aren't checking the condition and body
|
||||||
// computations themselves.
|
// computations themselves.
|
||||||
@ -3086,6 +3101,15 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
|||||||
case HloOpcode::kConstant:
|
case HloOpcode::kConstant:
|
||||||
SetInstructionAsConstant(new_instr, id, new_shape, false);
|
SetInstructionAsConstant(new_instr, id, new_shape, false);
|
||||||
break;
|
break;
|
||||||
|
case HloOpcode::kCustomCall:
|
||||||
|
if (instr_proto->custom_call_target() == "SetBound") {
|
||||||
|
SetInstructionAsConstant(new_instr, id, new_shape, true);
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Dynamic inferencing on custom call %s is not supported",
|
||||||
|
instr_proto->DebugString());
|
||||||
|
}
|
||||||
case HloOpcode::kParameter:
|
case HloOpcode::kParameter:
|
||||||
SetInstructionAsConstant(new_instr, id, new_shape, true);
|
SetInstructionAsConstant(new_instr, id, new_shape, true);
|
||||||
break;
|
break;
|
||||||
@ -3149,7 +3173,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
|||||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
||||||
StringToHloOpcode(instr_proto->opcode()));
|
StringToHloOpcode(instr_proto->opcode()));
|
||||||
if (next_operand >= instr_proto->operand_ids_size() ||
|
if (next_operand >= instr_proto->operand_ids_size() ||
|
||||||
opcode == HloOpcode::kGetDimensionSize) {
|
opcode == HloOpcode::kGetDimensionSize ||
|
||||||
|
InstrIsSetBound(instr_proto)) {
|
||||||
// No more operands to process, process self.
|
// No more operands to process, process self.
|
||||||
int64 new_id = ++global_id;
|
int64 new_id = ++global_id;
|
||||||
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
|
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
|
||||||
@ -3234,27 +3259,34 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
|
||||||
LookUpInstructionByHandle(handle));
|
LookUpInstructionByHandle(handle));
|
||||||
|
|
||||||
|
if (instr_proto->opcode() ==
|
||||||
|
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
|
||||||
|
InstrIsSetBound(instr_proto)) {
|
||||||
|
int32 constant_value = -1;
|
||||||
if (instr_proto->opcode() ==
|
if (instr_proto->opcode() ==
|
||||||
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||||
// At this point, BuildConstantSubGraph should never encounter a
|
// At this point, BuildConstantSubGraph should never encounter a
|
||||||
// GetDimensionSize with a dynamic dimension. IsConstant check would have
|
// GetDimensionSize with a dynamic dimension. IsConstant check would
|
||||||
// failed at the beginning of this function.
|
// have failed at the beginning of this function.
|
||||||
//
|
//
|
||||||
// Replace GetDimensionSize with a Constant representing the static bound
|
// Replace GetDimensionSize with a Constant representing the static
|
||||||
// of the shape.
|
// bound of the shape.
|
||||||
int64 dimension = instr_proto->dimensions(0);
|
int64 dimension = instr_proto->dimensions(0);
|
||||||
int64 operand_handle = instr_proto->operand_ids(0);
|
int64 operand_handle = instr_proto->operand_ids(0);
|
||||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
|
||||||
LookUpInstructionByHandle(operand_handle));
|
LookUpInstructionByHandle(operand_handle));
|
||||||
|
|
||||||
int32 constant_dimension_size = -1;
|
|
||||||
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
|
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
|
||||||
dynamic_dimension_is_minus_one)) {
|
dynamic_dimension_is_minus_one)) {
|
||||||
constant_dimension_size =
|
constant_value =
|
||||||
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
TF_RET_CHECK(
|
||||||
|
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
|
||||||
|
}
|
||||||
|
|
||||||
Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
|
Literal literal = LiteralUtil::CreateR0(constant_value);
|
||||||
|
|
||||||
HloInstructionProto const_instr;
|
HloInstructionProto const_instr;
|
||||||
*const_instr.mutable_shape() = literal.shape().ToProto();
|
*const_instr.mutable_shape() = literal.shape().ToProto();
|
||||||
@ -3286,6 +3318,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (InstrIsSetBound(instr_src)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto* instr = entry.add_instructions();
|
auto* instr = entry.add_instructions();
|
||||||
|
|
||||||
*instr = *instr_src;
|
*instr = *instr_src;
|
||||||
|
@ -179,6 +179,22 @@ StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards every use of a "SetBound" custom call to the call's first operand.
//
// Returns false without touching `instr` when it is not a "SetBound" custom
// call, and true once its uses have been redirected. Fails if the call and
// its operand disagree in shape (dynamic dimensions ignored) or if the
// replacement itself fails.
StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
  const bool is_set_bound = instr->opcode() == HloOpcode::kCustomCall &&
                            instr->custom_call_target() == "SetBound";
  if (!is_set_bound) {
    return false;
  }

  // The condition is spelled out inside the macro because TF_RET_CHECK
  // embeds the expression text in its failure message.
  TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
      instr->shape(), instr->operand(0)->shape()))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "instruction operand shape " << instr->operand(0)->shape();

  HloInstruction* bounded_value = instr->mutable_operand(0);
  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(bounded_value));
  return true;
}
|
||||||
|
|
||||||
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
|
||||||
int64 dimension) {
|
int64 dimension) {
|
||||||
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
if ((inst->opcode() == HloOpcode::kReduceWindow ||
|
||||||
@ -1370,7 +1386,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
|||||||
for (auto* computation : module->computations()) {
|
for (auto* computation : module->computations()) {
|
||||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||||
|
TF_ASSIGN_OR_RETURN(bool replaced_set_bound,
|
||||||
|
ReplaceSetBound(instruction));
|
||||||
changed = changed || replaced_set_size;
|
changed = changed || replaced_set_size;
|
||||||
|
changed = changed || replaced_set_bound;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user