Support dynamic sample size in categorical op.
- Support dynamic sample size in categorical op. - Change dynamic tf.range to use ResolveInputDynamismIntoPred as it gets more mature now. This is more elegant and works better with tf2xla.set_bound API. - Add a testcase. PiperOrigin-RevId: 346413784 Change-Id: Ia76717f678e1b6e306c0cb902287a1cd91779626
This commit is contained in:
parent
b196017e8a
commit
c9cf71a50d
@ -91,6 +91,13 @@ class CategoricalOp : public XlaOpKernel {
|
||||
xla::PrimitiveType type;
|
||||
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
|
||||
xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
|
||||
bool num_samples_is_dynamic;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
|
||||
if (num_samples_is_dynamic && num_samples != 1) {
|
||||
// Number samples is dimension 1 in uniform_shape_array.
|
||||
log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1);
|
||||
}
|
||||
|
||||
// Use Gumbel softmax trick to generate categorical samples.
|
||||
// See:
|
||||
|
||||
@ -110,16 +110,15 @@ class RangeOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, output.status());
|
||||
|
||||
if (type == DT_INT32 || type == DT_INT64) {
|
||||
// If input has dynamic dimension (value is -1), propagate the dynamic
|
||||
// dimension to output using set-dimension-size.
|
||||
ctx->set_dynamic_dimension_is_minus_one(true);
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
|
||||
bool limit_is_dynamic = false;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
|
||||
if (type == DT_INT32) {
|
||||
if (limit.Get<int32>({}) == -1) {
|
||||
if (limit_is_dynamic) {
|
||||
output = xla::SetDimensionSize(output.ValueOrDie(), ctx->Input(1), 0);
|
||||
}
|
||||
} else {
|
||||
if (limit.Get<int64>({}) == -1) {
|
||||
if (limit_is_dynamic) {
|
||||
output = xla::SetDimensionSize(
|
||||
output.ValueOrDie(),
|
||||
xla::ConvertElementType(ctx->Input(1), xla::S32), 0);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user