Add an op to create bounded dynamic dimensions.

This is mainly for writing tests to create a bound to a tensor -- The dynamic dimensions in real workloads should be handled automatically by the framework.

PiperOrigin-RevId: 339061664
Change-Id: I23a00638280e9bbac7d923b78d9326cab0b04d5f
This commit is contained in:
Yunxing Dai 2020-10-26 10:05:52 -07:00 committed by TensorFlower Gardener
parent 9bcee380d1
commit 6ca5914663
4 changed files with 60 additions and 0 deletions

View File

@ -2061,6 +2061,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSelfAdjointEig",
"XlaSend",
"XlaSetBound",
"XlaSetDynamicDimensionSize",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",

View File

@ -108,6 +108,38 @@ class XlaSetBoundOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
XlaSetBoundOp);
class XlaSetDynamicDimensionSizeOp : public XlaOpKernel {
public:
explicit XlaSetDynamicDimensionSizeOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape dim_index_shape = ctx->InputShape("dim_index");
const TensorShape size_shape = ctx->InputShape("size");
OP_REQUIRES(ctx,
ctx->InputType("dim_index") == DT_INT32 &&
ctx->InputType("size") == DT_INT32,
errors::InvalidArgument("dim_index and size has to be int32 for"
"XlaSetDynamicDimensionSizeOp"));
OP_REQUIRES(
ctx, dim_index_shape.dims() == 0,
errors::InvalidArgument("XlaSetDynamicDimensionSizeOp's dim_index and "
"size has to be int32 scalar value"));
int64 dim_index;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index));
xla::XlaOp result =
xla::SetDimensionSize(ctx->Input(0), ctx->Input("size"), dim_index);
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(
Name("XlaSetDynamicDimensionSize").CompileTimeConstantInput("dim_index"),
XlaSetDynamicDimensionSizeOp);
class ShapeNOp : public XlaOpKernel {
public:
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {

View File

@ -301,6 +301,18 @@ REGISTER_OP("XlaSetBound")
returns the same value.
)doc");
REGISTER_OP("XlaSetDynamicDimensionSize")
.Input("input: T")
.Input("dim_index: int32")
.Input("size: int32")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(
R"doc(Make a static dimension into a xla bounded dynamic dimension.
The current static dimension size will become the bound and the second
operand becomes the dynamic size of the dimension.)doc");
REGISTER_OP("XlaDynamicSlice")
.Input("input: T")
.Input("start_indices: Tindices")

View File

@ -399,6 +399,21 @@ replica_id = gen_xla_ops.xla_replica_id
set_bound = gen_xla_ops.xla_set_bound
# Make a static dimension into a xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
# array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
# # Tells xla the valid size of the array is 3.
# dim = 0
# p = xla_set_dynamic_dimension_size(array, dim, 3)
# assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
def reshape(x, new_sizes, dimensions=None, name=None):
if dimensions is not None:
x = array_ops.transpose(x, dimensions)