Add xla.set_bound op.

For cases where we cannot infer the bound of a value, the compilation would fail. This gives user an escape patch.

PiperOrigin-RevId: 329626655
Change-Id: Ib5d71054088692697eaf5f2b21c0c5d1a097f1eb
This commit is contained in:
Yunxing Dai 2020-09-01 19:04:14 -07:00 committed by TensorFlower Gardener
parent dfaa328f06
commit 9c703cc790
6 changed files with 135 additions and 19 deletions

View File

@ -2079,6 +2079,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",
"XlaSetBound",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Shape Ops.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
class XlaSetBoundOp : public XlaOpKernel {
public:
explicit XlaSetBoundOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape("input");
const TensorShape bound_shape = ctx->InputShape("bound");
OP_REQUIRES(
ctx,
ctx->InputType("bound") == DT_INT32 &&
ctx->InputType("input") == DT_INT32,
errors::InvalidArgument(
"XlaSetBound can only set bound for int32 scalar value: got",
input_shape.DebugString()));
OP_REQUIRES(
ctx, input_shape.dims() == 0,
errors::InvalidArgument("XlaSetBound should only be used to set a "
"bound to the an int32 scalar value: got",
input_shape.DebugString()));
OP_REQUIRES(
ctx, bound_shape.dims() == 0,
errors::InvalidArgument("XlaSetBound should only be used to set a "
"bound to the an int32 scalar value: got",
bound_shape.DebugString()));
int64 bound;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
xla::XlaOp result = xla::CustomCall(
ctx->builder(), "SetBound", {ctx->Input("input")},
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
XlaSetBoundOp);
class ShapeNOp : public XlaOpKernel {
public:
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {

View File

@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaSetBound")
.Input("input: int32")
.Input("bound: int32")
.Output("output: int32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(
R"doc(Set a bound for the given input value as a hint to Xla compiler,
returns the same value.
)doc");
REGISTER_OP("XlaDynamicSlice")
.Input("input: T")
.Input("start_indices: Tindices")

View File

@ -387,6 +387,14 @@ def reduce_window(operand,
replica_id = gen_xla_ops.xla_replica_id
# Set a static bound for the given input value as a hint to Xla compiler,
# returns the same value.
# Usage:
# def f(t, p):
# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
# return t[:p] # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound
def reshape(x, new_sizes, dimensions=None, name=None):
if dimensions is not None:

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@ -117,6 +119,15 @@ HloComputationProto CreateReduceOr(int64 reducer_id,
}
return reducer;
}
bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
if (opcode == HloOpcode::kCustomCall &&
instr_proto->custom_call_target() == "SetBound") {
return true;
}
return false;
}
} // namespace
namespace internal {
@ -293,7 +304,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
// GetDimensionSize is always considered constant in XLA -- If a dynamic
// dimension is presented, -1 is returned.
break;
// Non functional ops.
case HloOpcode::kRng:
case HloOpcode::kAllReduce:
@ -306,6 +316,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
// cannot be constant. We cannot set is_functional=false in other similar
// cases since we're already relying on IsConstant to return true.
case HloOpcode::kCustomCall:
if (instr.custom_call_target() == "SetBound") {
// Set bound is considered constant -- the bound is used as the value.
break;
}
TF_FALLTHROUGH_INTENDED;
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
@ -3086,6 +3101,15 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
case HloOpcode::kConstant:
SetInstructionAsConstant(new_instr, id, new_shape, false);
break;
case HloOpcode::kCustomCall:
if (instr_proto->custom_call_target() == "SetBound") {
SetInstructionAsConstant(new_instr, id, new_shape, true);
break;
} else {
return InvalidArgument(
"Dynamic inferencing on custom call %s is not supported",
instr_proto->DebugString());
}
case HloOpcode::kParameter:
SetInstructionAsConstant(new_instr, id, new_shape, true);
break;
@ -3149,7 +3173,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instr_proto->opcode()));
if (next_operand >= instr_proto->operand_ids_size() ||
opcode == HloOpcode::kGetDimensionSize) {
opcode == HloOpcode::kGetDimensionSize ||
InstrIsSetBound(instr_proto)) {
// No more operands to process, process self.
int64 new_id = ++global_id;
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
@ -3235,26 +3260,33 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
LookUpInstructionByHandle(handle));
if (instr_proto->opcode() ==
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
// At this point, BuildConstantSubGraph should never encounter a
// GetDimensionSize with a dynamic dimension. IsConstant check would have
// failed at the beginning of this function.
//
// Replace GetDimensionSize with a Constant representing the static bound
// of the shape.
int64 dimension = instr_proto->dimensions(0);
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
InstrIsSetBound(instr_proto)) {
int32 constant_value = -1;
if (instr_proto->opcode() ==
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
// At this point, BuildConstantSubGraph should never encounter a
// GetDimensionSize with a dynamic dimension. IsConstant check would
// have failed at the beginning of this function.
//
// Replace GetDimensionSize with a Constant representing the static
// bound of the shape.
int64 dimension = instr_proto->dimensions(0);
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
int32 constant_dimension_size = -1;
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
dynamic_dimension_is_minus_one)) {
constant_dimension_size =
static_cast<int32>(operand_proto->shape().dimensions(dimension));
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
dynamic_dimension_is_minus_one)) {
constant_value =
static_cast<int32>(operand_proto->shape().dimensions(dimension));
}
} else {
TF_RET_CHECK(
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
}
Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
Literal literal = LiteralUtil::CreateR0(constant_value);
HloInstructionProto const_instr;
*const_instr.mutable_shape() = literal.shape().ToProto();
@ -3286,6 +3318,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
continue;
}
if (InstrIsSetBound(instr_src)) {
continue;
}
auto* instr = entry.add_instructions();
*instr = *instr_src;

View File

@ -179,6 +179,22 @@ StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
return true;
}
StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kCustomCall ||
instr->custom_call_target() != "SetBound") {
return false;
}
TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
instr->shape(), instr->operand(0)->shape()))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "instruction operand shape " << instr->operand(0)->shape();
HloInstruction* operand = instr->mutable_operand(0);
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
return true;
}
bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
int64 dimension) {
if ((inst->opcode() == HloOpcode::kReduceWindow ||
@ -1370,7 +1386,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
for (auto* computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
TF_ASSIGN_OR_RETURN(bool replaced_set_bound,
ReplaceSetBound(instruction));
changed = changed || replaced_set_size;
changed = changed || replaced_set_bound;
}
}