[XLA] Several improvements to dynamic padder.
- Support partial slice on dynamic dimensions -- this is achieved by letting the client to set the dynamic dimension after building a xla slice. - Support dynamic pad on padded dimension. - Fix a terrible bug exposed by rxsang's experiment where transpose creates wrong dynamic dimension. PiperOrigin-RevId: 303033314 Change-Id: Id76f4619d12e88b8c0b3e9ec75baa1d78d1a7270
This commit is contained in:
parent
2d055a4226
commit
b3212dd802
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/ops_util.h"
|
#include "tensorflow/core/framework/ops_util.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
@ -58,18 +59,21 @@ class SliceOp : public XlaOpKernel {
|
|||||||
std::vector<int64> begin;
|
std::vector<int64> begin;
|
||||||
std::vector<int64> size;
|
std::vector<int64> size;
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
|
||||||
|
std::vector<int64> wrapped_size(size.size());
|
||||||
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
|
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
|
||||||
// `begin` is a compile-time constant.
|
// `begin` is a compile-time constant.
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
for (int i = 0; i < input_dims; ++i) {
|
||||||
if (size[i] == -1) {
|
if (size[i] == -1) {
|
||||||
// A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
|
// A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
|
||||||
size[i] = input_shape.dim_size(i) - begin[i];
|
wrapped_size[i] = input_shape.dim_size(i) - begin[i];
|
||||||
|
} else {
|
||||||
|
wrapped_size[i] = size[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
for (int i = 0; i < input_dims; ++i) {
|
||||||
int64 b = begin[i];
|
int64 b = begin[i];
|
||||||
int64 s = size[i];
|
int64 s = wrapped_size[i];
|
||||||
if (input_shape.dim_size(i) == 0) {
|
if (input_shape.dim_size(i) == 0) {
|
||||||
OP_REQUIRES(ctx, b == 0 && s == 0,
|
OP_REQUIRES(ctx, b == 0 && s == 0,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -91,10 +95,28 @@ class SliceOp : public XlaOpKernel {
|
|||||||
std::vector<int64> limits;
|
std::vector<int64> limits;
|
||||||
limits.reserve(begin.size());
|
limits.reserve(begin.size());
|
||||||
for (int i = 0; i < begin.size(); ++i) {
|
for (int i = 0; i < begin.size(); ++i) {
|
||||||
limits.push_back(begin[i] + size[i]);
|
limits.push_back(begin[i] + wrapped_size[i]);
|
||||||
}
|
}
|
||||||
std::vector<int64> strides(begin.size(), 1);
|
std::vector<int64> strides(begin.size(), 1);
|
||||||
ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
|
auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
|
||||||
|
// Check for slice on dynamic dimensions.
|
||||||
|
ctx->set_dynamic_dimension_is_minus_one(true);
|
||||||
|
std::vector<int64> dynamic_size;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
|
||||||
|
|
||||||
|
for (int64 i = 0; i < size.size(); ++i) {
|
||||||
|
if (dynamic_size[i] == -1) {
|
||||||
|
if (size[i] != -1) {
|
||||||
|
// If there is a dynamic dimension, properly set dimension size of
|
||||||
|
// the slice.
|
||||||
|
auto dynamic_size =
|
||||||
|
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
|
||||||
|
|
||||||
|
slice = xla::SetDimensionSize(slice, dynamic_size, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx->SetOutput(0, slice);
|
||||||
} else {
|
} else {
|
||||||
// `begin` is not a compile-time constant.
|
// `begin` is not a compile-time constant.
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
for (int i = 0; i < input_dims; ++i) {
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/mem.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -115,6 +116,72 @@ class StridedSliceOp : public XlaOpKernel {
|
|||||||
slice = xla::Rev(slice, dimensions_to_reverse);
|
slice = xla::Rev(slice, dimensions_to_reverse);
|
||||||
}
|
}
|
||||||
slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
|
slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
|
||||||
|
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
|
||||||
|
OP_REQUIRES_OK(ctx, operand_shape_or.status());
|
||||||
|
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
|
||||||
|
if (xla_shape.is_static()) {
|
||||||
|
// Static output shape, return a static slice.
|
||||||
|
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
||||||
|
ctx->SetOutput(0, slice);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto input_dim_sizes = input_shape.dim_sizes();
|
||||||
|
|
||||||
|
for (int64 i = 0; i < xla_shape.rank(); ++i) {
|
||||||
|
if (xla_shape.is_dynamic_dimension(i)) {
|
||||||
|
input_dim_sizes[i] = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PartialTensorShape input_partial_shape(input_dim_sizes);
|
||||||
|
partial_final_shape.Clear();
|
||||||
|
end.clear();
|
||||||
|
strides.clear();
|
||||||
|
begin.clear();
|
||||||
|
// Run shape inferenference again with partial shape.
|
||||||
|
OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
|
||||||
|
&begin_tensor, &end_tensor, strides_tensor,
|
||||||
|
input_partial_shape, begin_mask_, end_mask_,
|
||||||
|
ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||||
|
&dummy_processing_shape, &partial_final_shape,
|
||||||
|
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
||||||
|
if (partial_final_shape.AsTensorShape(&final_shape)) {
|
||||||
|
// Static output shape, return a static slice.
|
||||||
|
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
||||||
|
ctx->SetOutput(0, slice);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We consider slicing a dynamic tensor t with negative indices as a
|
||||||
|
// dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
|
||||||
|
for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
|
||||||
|
bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
|
||||||
|
bool backward_slice = end[i] < 0;
|
||||||
|
if (dynamic_dim && backward_slice) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, strides[i] == 1,
|
||||||
|
errors::InvalidArgument("XLA has not implemented dynamic "
|
||||||
|
"sized slice with non-trival stride yet. "
|
||||||
|
"Please file a bug against XLA"));
|
||||||
|
|
||||||
|
OP_REQUIRES(ctx, begin[i] >= 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"XLA has not implemented dynamic "
|
||||||
|
"sized slice with negative begin index %lld. "
|
||||||
|
"Please file a bug against XLA",
|
||||||
|
begin[i]));
|
||||||
|
// If there is a dynamic dimension, properly set dimension size of
|
||||||
|
// the result.
|
||||||
|
auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
|
||||||
|
|
||||||
|
operand_size = xla::Add(
|
||||||
|
operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
|
||||||
|
slice = xla::SetDimensionSize(
|
||||||
|
slice,
|
||||||
|
xla::Sub(operand_size,
|
||||||
|
xla::ConstantR0<int32>(ctx->builder(), begin[i])),
|
||||||
|
i);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// When output shape is fully defined, it must be a size one slice:
|
// When output shape is fully defined, it must be a size one slice:
|
||||||
//
|
//
|
||||||
|
@ -2404,6 +2404,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:window_util",
|
"//tensorflow/compiler/xla:window_util",
|
||||||
"//tensorflow/core/platform:macros",
|
"//tensorflow/core/platform:macros",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/window_util.h"
|
#include "tensorflow/compiler/xla/window_util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -250,15 +251,25 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
const PaddingConfig_PaddingConfigDimension& padding_config =
|
const PaddingConfig_PaddingConfigDimension& padding_config =
|
||||||
hlo->padding_config().dimensions(dimension);
|
hlo->padding_config().dimensions(dimension);
|
||||||
if (padding_config.interior_padding() == 0 &&
|
if (padding_config.interior_padding() == 0) {
|
||||||
padding_config.edge_padding_low() == 0 &&
|
HloInstruction* dynamic_size_adjusted = dynamic_size;
|
||||||
padding_config.edge_padding_high() == 0) {
|
HloInstruction* adjustment = hlo->parent()->AddInstruction(
|
||||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
|
||||||
|
padding_config.edge_padding_low() +
|
||||||
|
padding_config.edge_padding_high())));
|
||||||
|
dynamic_size_adjusted =
|
||||||
|
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
dynamic_size_adjusted->shape(), HloOpcode::kAdd,
|
||||||
|
dynamic_size_adjusted, adjustment));
|
||||||
|
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
|
||||||
|
constraint);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"Dynamic dimension propagation on padding dimension is not "
|
"Dynamic dimension propagation on interio padding dimension is "
|
||||||
"supported.");
|
"not "
|
||||||
|
"supported: %s",
|
||||||
|
hlo->ToString());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -400,11 +411,19 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
|
|||||||
|
|
||||||
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
||||||
return ForEachOperandDynamicDimension(
|
return ForEachOperandDynamicDimension(
|
||||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
hlo,
|
||||||
int64 operand_index, HloInstruction* dynamic_size,
|
[&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||||
DimensionConstraint constraint) {
|
int64 operand_index, HloInstruction* dynamic_size,
|
||||||
parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension],
|
DimensionConstraint constraint) -> Status {
|
||||||
dynamic_size, constraint);
|
int64 permuted_dim = -1;
|
||||||
|
for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
|
||||||
|
if (hlo->dimensions()[i] == dimension) {
|
||||||
|
TF_RET_CHECK(permuted_dim == -1);
|
||||||
|
permuted_dim = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
|
||||||
|
constraint);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -979,14 +998,8 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
|
|||||||
hlo->slice_strides(dimension) != 1 ||
|
hlo->slice_strides(dimension) != 1 ||
|
||||||
hlo->slice_limits(dimension) !=
|
hlo->slice_limits(dimension) !=
|
||||||
operand->shape().dimensions(dimension)) {
|
operand->shape().dimensions(dimension)) {
|
||||||
// Slicing a single element out eliminates the dynamic dimension.
|
// Slicing a partial element out eliminates the dynamic dimension.
|
||||||
if (hlo->shape().dimensions(dimension) == 1) {
|
return Status::OK();
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return Unimplemented(
|
|
||||||
"Dynamic dimension propagation on Slice where it doesn't slice "
|
|
||||||
"out an entire dimension is not supported %s",
|
|
||||||
hlo->ToString());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||||
|
@ -386,6 +386,53 @@ TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
|
|||||||
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
|
||||||
|
auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
|
||||||
|
auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});
|
||||||
|
|
||||||
|
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/0, lhs_shape, "A"));
|
||||||
|
auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/1, rhs_shape, "B"));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/2, scalar_shape_, "size_param"));
|
||||||
|
|
||||||
|
DotDimensionNumbers dot_dnums;
|
||||||
|
dot_dnums.add_lhs_contracting_dimensions(0);
|
||||||
|
dot_dnums.add_lhs_contracting_dimensions(1);
|
||||||
|
dot_dnums.add_rhs_contracting_dimensions(0);
|
||||||
|
dot_dnums.add_rhs_contracting_dimensions(1);
|
||||||
|
auto dot = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
|
||||||
|
HloTestBase::DefaultPrecisionConfig(2)));
|
||||||
|
|
||||||
|
module_->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 1}));
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{1, {}, 0}));
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{1, {}, 1}));
|
||||||
|
|
||||||
|
SCOPED_TRACE(module_->ToString());
|
||||||
|
TF_ASSERT_OK(RunInference());
|
||||||
|
// Nothing is dynamic in the output.
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
|
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
constexpr int xdim = 3;
|
constexpr int xdim = 3;
|
||||||
@ -474,6 +521,45 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
|
|||||||
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
|
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
|
||||||
|
// Test the ability to trace unmodified dimensions
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
|
||||||
|
auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
|
||||||
|
|
||||||
|
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/0, input_shape, "A"));
|
||||||
|
auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/1, scalar_shape_, "size_param"));
|
||||||
|
auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/2, scalar_shape_, "size_param"));
|
||||||
|
auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/3, scalar_shape_, "size_param"));
|
||||||
|
|
||||||
|
auto* transpose = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1}));
|
||||||
|
|
||||||
|
module_->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{1, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 1}));
|
||||||
|
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{3, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 2}));
|
||||||
|
|
||||||
|
SCOPED_TRACE(module_->ToString());
|
||||||
|
TF_ASSERT_OK(RunInference());
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
|
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
|
||||||
// Test the ability to trace unmodified reshape dimensions.
|
// Test the ability to trace unmodified reshape dimensions.
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
@ -865,6 +865,45 @@ ENTRY main {
|
|||||||
EXPECT_EQ(result, expected);
|
EXPECT_EQ(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ExecutionTest, DynamicPad) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule TEST
|
||||||
|
|
||||||
|
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||||
|
lhs = s32[] parameter(0)
|
||||||
|
rhs = s32[] parameter(1)
|
||||||
|
ROOT add = s32[] add(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
param = s32[4] parameter(0)
|
||||||
|
size = s32[] constant(3)
|
||||||
|
padding = s32[] constant(2)
|
||||||
|
param_dynamic = s32[4] set-dimension-size(param, size),
|
||||||
|
dimensions={0}
|
||||||
|
// pad head and tail to 2
|
||||||
|
pad = s32[6] pad(param_dynamic, padding), padding=1_1
|
||||||
|
|
||||||
|
init = s32[] constant(0)
|
||||||
|
ROOT reduce = s32[] reduce(pad, init),
|
||||||
|
dimensions={0},
|
||||||
|
to_apply=update_s32
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
Literal operand = LiteralUtil::CreateR1<int32>({1, 4, 3, 5});
|
||||||
|
auto module = GetHloModule(hlo_text);
|
||||||
|
|
||||||
|
// After padding head and tail with "2", the effective data will be [2, 1, 4,
|
||||||
|
// 3, 2]
|
||||||
|
|
||||||
|
Literal result = PadAndExecute(std::move(module), {&operand},
|
||||||
|
/*slice_dynamic_output=*/false);
|
||||||
|
Literal expected = LiteralUtil::CreateR0<int32>(12);
|
||||||
|
|
||||||
|
EXPECT_EQ(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
|
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
|
||||||
const string hlo_text = R"(
|
const string hlo_text = R"(
|
||||||
HloModule TEST
|
HloModule TEST
|
||||||
|
Loading…
Reference in New Issue
Block a user