Support dynamic sample size in categorical op.

- Support dynamic sample size in categorical op.
- Change dynamic tf.range to use ResolveInputDynamismIntoPred as it gets more mature now. This is more elegant and works better with tf2xla.set_bound API.
- Add a testcase.

PiperOrigin-RevId: 346413784
Change-Id: Ia76717f678e1b6e306c0cb902287a1cd91779626
This commit is contained in:
Yunxing Dai 2020-12-08 14:40:22 -08:00 committed by TensorFlower Gardener
parent b196017e8a
commit c9cf71a50d
2 changed files with 12 additions and 6 deletions

View File

@ -91,6 +91,13 @@ class CategoricalOp : public XlaOpKernel {
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
bool num_samples_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
if (num_samples_is_dynamic && num_samples != 1) {
// Number samples is dimension 1 in uniform_shape_array.
log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1);
}
// Use Gumbel softmax trick to generate categorical samples.
// See:

View File

@ -110,16 +110,15 @@ class RangeOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, output.status());
if (type == DT_INT32 || type == DT_INT64) {
// If input has dynamic dimension (value is -1), propagate the dynamic
// dimension to output using set-dimension-size.
ctx->set_dynamic_dimension_is_minus_one(true);
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
bool limit_is_dynamic = false;
OP_REQUIRES_OK(ctx,
ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
if (type == DT_INT32) {
if (limit.Get<int32>({}) == -1) {
if (limit_is_dynamic) {
output = xla::SetDimensionSize(output.ValueOrDie(), ctx->Input(1), 0);
}
} else {
if (limit.Get<int64>({}) == -1) {
if (limit_is_dynamic) {
output = xla::SetDimensionSize(
output.ValueOrDie(),
xla::ConvertElementType(ctx->Input(1), xla::S32), 0);