Add an op to create bounded dynamic dimensions.

This is mainly for writing tests that attach a bound to a tensor dimension -- dynamic dimensions in real workloads should be handled automatically by the framework.

PiperOrigin-RevId: 339061664
Change-Id: I23a00638280e9bbac7d923b78d9326cab0b04d5f
parent 9bcee380d1
commit 6ca5914663
@@ -2061,6 +2061,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
             "XlaSelfAdjointEig",
             "XlaSend",
             "XlaSetBound",
+            "XlaSetDynamicDimensionSize",
             "XlaSharding",
             "XlaSort",
             "XlaSpmdFullToShardShape",
@@ -108,6 +108,38 @@ class XlaSetBoundOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
                 XlaSetBoundOp);
 
+class XlaSetDynamicDimensionSizeOp : public XlaOpKernel {
+ public:
+  explicit XlaSetDynamicDimensionSizeOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape dim_index_shape = ctx->InputShape("dim_index");
+    const TensorShape size_shape = ctx->InputShape("size");
+
+    OP_REQUIRES(ctx,
+                ctx->InputType("dim_index") == DT_INT32 &&
+                    ctx->InputType("size") == DT_INT32,
+                errors::InvalidArgument("dim_index and size have to be int32 "
+                                        "for XlaSetDynamicDimensionSizeOp"));
+
+    OP_REQUIRES(
+        ctx, dim_index_shape.dims() == 0 && size_shape.dims() == 0,
+        errors::InvalidArgument("XlaSetDynamicDimensionSizeOp's dim_index and "
+                                "size have to be int32 scalar values"));
+    int64 dim_index;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index));
+
+    xla::XlaOp result =
+        xla::SetDimensionSize(ctx->Input(0), ctx->Input("size"), dim_index);
+    ctx->SetOutput(0, result);
+  }
+};
+
+REGISTER_XLA_OP(
+    Name("XlaSetDynamicDimensionSize").CompileTimeConstantInput("dim_index"),
+    XlaSetDynamicDimensionSizeOp);
+
 class ShapeNOp : public XlaOpKernel {
  public:
   explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
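For orientation, a minimal sketch (not part of this commit) of how the kernel's constraints surface from Python. It assumes the set_dynamic_dimension_size wrapper added later in this change and a TensorFlow build with XLA; the function name bounded_sum is illustrative only.

# Sketch only: exercises the constraints checked by XlaSetDynamicDimensionSizeOp.
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla  # wrapper added in this commit


# The op only has an XLA kernel, so it must run inside an XLA-compiled
# function (older TF versions spell this `experimental_compile=True`).
@tf.function(jit_compile=True)
def bounded_sum(array, size):
  # dim_index (0 here) is a CompileTimeConstantInput and must fold to a
  # constant at compile time; both dim_index and size must be int32 scalars.
  p = xla.set_dynamic_dimension_size(array, 0, size)
  return tf.reduce_sum(p)  # the reduction only sees the first `size` elements


print(bounded_sum(tf.constant([1, 2, 3, 4, 5]),
                  tf.constant(3, dtype=tf.int32)))  # expected: 6

try:
  # A vector `size` violates the scalar requirement checked above.
  bounded_sum(tf.constant([1, 2, 3, 4, 5]), tf.constant([3], dtype=tf.int32))
except Exception as e:
  print("rejected non-scalar size:", type(e).__name__)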
@@ -301,6 +301,18 @@ REGISTER_OP("XlaSetBound")
 returns the same value.
 )doc");
 
+REGISTER_OP("XlaSetDynamicDimensionSize")
+    .Input("input: T")
+    .Input("dim_index: int32")
+    .Input("size: int32")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(
+        R"doc(Make a static dimension into an xla bounded dynamic dimension.
+The current static dimension size will become the bound and the second
+operand becomes the dynamic size of the dimension.)doc");
+
 REGISTER_OP("XlaDynamicSlice")
     .Input("input: T")
     .Input("start_indices: Tindices")
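For orientation only (not part of the commit): because the op uses shape_inference::UnchangedShape, TensorFlow's static shape inference keeps the input's shape, and the bound of the new dynamic dimension comes from that static size; only the size seen inside the compiled computation changes. A hedged Python sketch, assuming the wrapper added below and that tf.shape lowers to xla::GetDimensionSize for dynamic dimensions, as the Shape kernels in the same file as the new kernel do.

# Sketch only: UnchangedShape keeps the static TF shape (the bound stays 5),
# while the size visible inside the XLA computation becomes 3.
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla


@tf.function(jit_compile=True)
def dynamic_dim_demo(array):
  resized = xla.set_dynamic_dimension_size(array, 0, tf.constant(3, tf.int32))
  # Printed once at tracing time: static shape inference still reports (5,).
  print("static shape during tracing:", resized.shape)
  # Inside the compiled computation the runtime size of dim 0 is the dynamic
  # size (assumption: tf.shape lowers to GetDimensionSize here).
  return tf.shape(resized)[0]


print(dynamic_dim_demo(tf.constant([1, 2, 3, 4, 5])))  # expected: 3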
@@ -399,6 +399,21 @@ replica_id = gen_xla_ops.xla_replica_id
 set_bound = gen_xla_ops.xla_set_bound
 
 
+# Make a static dimension into an xla bounded dynamic dimension. The current
+# static dimension size will become the bound and the second operand becomes
+# the dynamic size of the dimension.
+#
+# This should mostly be used for testing.
+#
+# def f():
+#   array = tf.convert_to_tensor([1, 2, 3, 4, 5])
+#   # Tells xla the valid size of the array is 3.
+#   dim = 0
+#   p = xla_set_dynamic_dimension_size(array, dim, 3)
+#   assert(reduce_sum(p) == 6)  # xla knows only the first 3 elements are valid.
+set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
+
+
 def reshape(x, new_sizes, dimensions=None, name=None):
   if dimensions is not None:
     x = array_ops.transpose(x, dimensions)
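The commented example above is pseudo-code; a runnable variant (a sketch, not part of the commit) would wrap the call in an XLA-compiled tf.function, since the op is registered only as an XLA kernel.

# Sketch only: runnable variant of the doc-comment example above.
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla


@tf.function(jit_compile=True)
def masked_sum():
  array = tf.constant([1, 2, 3, 4, 5])
  # Bound = 5 (the static size of dim 0); dynamic size = 3.
  p = xla.set_dynamic_dimension_size(array, 0, 3)
  return tf.reduce_sum(p)


assert int(masked_sum()) == 6  # xla treats only the first 3 elements as valid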