[XLA] Adapt HLO pipeline to scalar-index DS and DUS
This adapts the HLO pipeline to work with the scalar DynamicSlice / DynamicUpdateSlice form, which means it also effectively switches the backends to use the new form. There is also a fair amount of test churn: tests that don't run optimizations (and, hence, don't trigger the splitter pass) needed to be converted by hand.

PiperOrigin-RevId: 227908542
parent 4614b04cc0
commit 962d8821eb
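For orientation, a dynamic-slice (DS) or dynamic-update-slice (DUS) previously carried a single rank-1 start_indices operand; in the scalar form it carries one scalar (s32[]) start-index operand per dimension of the sliced operand. Below is a minimal sketch using the same C++ HLO-construction API that appears throughout this diff; the function name, variable names, and shapes are illustrative assumptions, not part of this change:

    // Old form: one R1 operand holding all start indices.
    //   ds = f32[1,8] dynamic-slice(f32[4,8] %p0, s32[2] %starts),
    //        dynamic_slice_sizes={1,8}
    // Scalar form: one s32[] operand per dimension.
    //   ds = f32[1,8] dynamic-slice(f32[4,8] %p0, s32[] %i0, s32[] %i1),
    //        dynamic_slice_sizes={1,8}
    // Hypothetical helper; assumes the usual XLA service headers.
    HloInstruction* BuildScalarIndexDynamicSlice(HloComputation* computation,
                                                 HloInstruction* p0,
                                                 HloInstruction* i0,
                                                 HloInstruction* i1) {
      Shape ds_shape = ShapeUtil::MakeShape(F32, {1, 8});
      return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
          ds_shape, p0, {i0, i1}, /*slice_sizes=*/{1, 8}));
    }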
@@ -1380,6 +1380,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
   // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.

   bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
+  HloDynamicSliceInstruction* dynamic_slice =
+      lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
+                           : Cast<HloDynamicSliceInstruction>(rhs);

   // ctA:
   HloInstruction* left_operand =
@@ -1397,8 +1400,6 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
           HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
                                     dnums, dot->precision_config()));
   // Get pair {start, 0} or {0, start}.
-  HloInstruction* original_start_indices =
-      lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
   // Position of start:
   int index_of_non_zero_start = lhs_is_dynamic_slice
                                     ? 1 - lhs_contracting_dimension
@@ -1407,23 +1408,19 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
   int index_of_zero_start = 1 - index_of_non_zero_start;

   // Slice out start and 0 components and reorder if necessary.
-  auto indices_type = original_start_indices->shape().element_type();
+  auto indices_type = dynamic_slice->operand(1)->shape().element_type();
   Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
   Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
   HloInstruction* non_zero_start =
-      computation_->AddInstruction(HloInstruction::CreateSlice(
-          s_shape, original_start_indices, {index_of_non_zero_start},
-          {index_of_non_zero_start + 1}, {1}));
+      dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
   HloInstruction* zero_start =
-      computation_->AddInstruction(HloInstruction::CreateSlice(
-          s_shape, original_start_indices, {index_of_zero_start},
-          {index_of_zero_start + 1}, {1}));
-  HloInstruction* new_start_indices =
-      lhs_is_dynamic_slice
-          ? computation_->AddInstruction(HloInstruction::CreateConcatenate(
-                d_shape, {non_zero_start, zero_start}, 0))
-          : computation_->AddInstruction(HloInstruction::CreateConcatenate(
-                d_shape, {zero_start, non_zero_start}, 0));
+      dynamic_slice->mutable_operand(1 + index_of_zero_start);
+  std::vector<HloInstruction*> new_start_indices;
+  if (lhs_is_dynamic_slice) {
+    new_start_indices = {non_zero_start, zero_start};
+  } else {
+    new_start_indices = {zero_start, non_zero_start};
+  }

   // Build DynamicSlice(ctA x ctB).
   const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
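In short: because the dynamic-slice now already carries one scalar operand per dimension, the rewrite above no longer has to carve the old R1 index vector apart with Slice and reassemble it with Concatenate; it simply reuses the scalar operands (the real start and the zero) in the order the new dynamic-slice needs. The updated tests below check exactly that output shape, e.g. (copied form, names as in the tests):

    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
                                           m::Constant(), m::Constant())));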
@@ -4535,14 +4535,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {

   int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
   int32 start_col = (spec.lcd == 0) ? spec.s : 0;
-  const auto start_indices =
+  std::vector<HloInstruction*> start_indices = {
       builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR1<int32>({start_row, start_col})));
+          LiteralUtil::CreateR0<int32>(start_row))),
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR0<int32>(start_col)))};
   int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
   int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
-  Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
+  Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
-      ds_shape, lhs, start_indices, {slice_row_size, slice_col_size}));
+      ds_shape, lhs, start_indices, slice_sizes));

   int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
   int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
@@ -4575,7 +4578,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
   } else {
     EXPECT_THAT(computation->root_instruction(),
                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
-                                           m::Concatenate())));
+                                           m::Constant(), m::Constant())));
   }
 }
@@ -4613,14 +4616,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {

   int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
   int32 start_col = (spec.rcd == 0) ? spec.s : 0;
-  const auto start_indices =
+  std::vector<HloInstruction*> start_indices = {
       builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR1<int32>({start_row, start_col})));
+          LiteralUtil::CreateR0<int32>(start_row))),
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR0<int32>(start_col)))};
   int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
   int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
-  Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
+  Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
-      ds_shape, rhs, start_indices, {slice_row_size, slice_col_size}));
+      ds_shape, rhs, start_indices, slice_sizes));

   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
@@ -4645,7 +4651,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
   } else {
     EXPECT_THAT(computation->root_instruction(),
                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
-                                           m::Concatenate())));
+                                           m::Constant(), m::Constant())));
   }
 }
@@ -2486,9 +2486,9 @@ while_body {
   get-tuple-element.3 = s32[] get-tuple-element(state), index=0
   constant.2 = s32[] constant(128)
   add.5 = s32[] add(get-tuple-element.3, constant.2)
-  constant.3 = s32[3]{0} constant({0, 0, 0})
-  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3)
-  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3)
+  constant.3 = s32[] constant(0)
+  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
+  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
   ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
 }
@@ -1,6 +1,14 @@
 # Description:
 #   LLVM-based CPU backend for XLA.

+load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "mkl_deps",
+)
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load(":build_defs.bzl", "runtime_copts")
+
 licenses(["notice"])  # Apache 2.0

 package(
@@ -14,15 +22,6 @@ package_group(
     ],
 )

-load(":build_defs.bzl", "runtime_copts")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
-load(
-    "//third_party/mkl:build_defs.bzl",
-    "mkl_deps",
-)
-
 # Filegroup used to collect source files for dependency checking.
 filegroup(
     name = "c_srcs",
@@ -114,6 +113,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
         "//tensorflow/compiler/xla/service:dot_decomposer",
+        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
         "//tensorflow/compiler/xla/service:hlo",
@@ -69,6 +69,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
+#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -244,6 +245,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   HloPassPipeline pipeline("HLO passes through layout assignment");
   pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                             /*allow_mixed_precision=*/false);
+  pipeline.AddPass<DynamicIndexSplitter>();
   pipeline.AddPass<CpuHloSupportChecker>();

   ReducePrecisionInsertion::AddPasses(
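DynamicIndexSplitter is the "splitter pass" mentioned in the commit message: it rewrites R1-index DS/DUS into the scalar-index form, and it is now registered at the front of each backend pipeline (CPU here; GPU and the interpreter below). A minimal sketch of exercising it in isolation, assuming a parsed `module` as in the tests elsewhere in this diff:

    // Hypothetical standalone use; HloPassPipeline::Run returns StatusOr<bool>.
    HloPassPipeline pipeline("splitter-only");
    pipeline.AddPass<DynamicIndexSplitter>();
    TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
    // `changed` is true iff some R1-index dynamic-slice/-update-slice was
    // rewritten into the scalar-index form.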
@@ -3,6 +3,11 @@

 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")

 licenses(["notice"])  # Apache 2.0

@@ -24,12 +29,6 @@ filegroup(
     ]),
 )

-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
-    "tf_cuda_tests_tags",
-)
-
 xla_proto_library(
     name = "backend_configs",
     srcs = ["backend_configs.proto"],
@@ -700,6 +699,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
+        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
         "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
         "//tensorflow/compiler/xla/service:hlo",
@@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
     p.1 = s32[1]{0} parameter(1)
     p.2 = f16[1,96,1024]{2,1,0} parameter(2)
     c.0 = s32[] constant(0)
-    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
+    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
   }

   fusion.2 {
@@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
     p.2 = f16[1,96,1024]{2,1,0} parameter(2)
     c.0 = s32[] constant(0)
-    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
+    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
   }

   ENTRY entry {
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
+#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
@@ -152,6 +153,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     HloPassPipeline pipeline("optimization");
     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);
+    pipeline.AddPass<DynamicIndexSplitter>();
     pipeline.AddPass<GpuHloSupportChecker>();
     ReducePrecisionInsertion::AddPasses(
         &pipeline, hlo_module->config().debug_options(),
@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -105,12 +106,26 @@ StatusOr<HloInstruction*> MakeDynamicSliceHlo(
     absl::Span<const int64> slice_sizes) {
   HloComputation* computation = operand->parent();
   CHECK_EQ(computation, start_indices->parent());
+  int64 rank = start_indices->shape().dimensions(0);
+  std::vector<HloInstruction*> scalar_start_indices;
+  for (int i = 0; i < rank; ++i) {
+    // TODO(b/118437727): Update callers to provide scalars directly.
+    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
+        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
+        start_indices, {i}, {i + 1}, {1}));
+    scalar_start_indices.push_back(
+        computation->AddInstruction(HloInstruction::CreateReshape(
+            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
+            slice)));
+  }
+  std::vector<Shape> scalar_start_indices_shapes(
+      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
   TF_ASSIGN_OR_RETURN(
       Shape dynamic_slice_shape,
       ShapeInference::InferDynamicSliceShape(
-          operand->shape(), {start_indices->shape()}, slice_sizes));
+          operand->shape(), scalar_start_indices_shapes, slice_sizes));
   return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
-      dynamic_slice_shape, operand, start_indices, slice_sizes));
+      dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
 }

 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
@@ -119,12 +134,26 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
   HloComputation* computation = operand->parent();
   CHECK_EQ(computation, update->parent());
   CHECK_EQ(computation, start_indices->parent());
+  int64 rank = start_indices->shape().dimensions(0);
+  std::vector<HloInstruction*> scalar_start_indices;
+  for (int i = 0; i < rank; ++i) {
+    // TODO(b/118437727): Update callers to provide scalars directly.
+    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
+        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
+        start_indices, {i}, {i + 1}, {1}));
+    scalar_start_indices.push_back(
+        computation->AddInstruction(HloInstruction::CreateReshape(
+            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
+            slice)));
+  }
+  std::vector<Shape> scalar_start_indices_shapes(
+      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
   TF_ASSIGN_OR_RETURN(
       Shape dynamic_update_slice_shape,
       ShapeInference::InferDynamicUpdateSliceShape(
-          operand->shape(), update->shape(), {start_indices->shape()}));
+          operand->shape(), update->shape(), scalar_start_indices_shapes));
   return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-      dynamic_update_slice_shape, operand, update, start_indices));
+      dynamic_update_slice_shape, operand, update, scalar_start_indices));
 }

 StatusOr<HloInstruction*> MakeBroadcastHlo(
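Note the effect on callers: MakeDynamicSliceHlo and MakeDynamicUpdateSliceHlo keep their signatures, still accepting a single R1 start_indices instruction, and now split it into scalars internally (the TODO above tracks moving callers to scalars). A hedged sketch of what a caller sees, with assumed shapes and variable names:

    // start_indices_r1 is an existing s32[2] instruction in the computation.
    TF_ASSIGN_OR_RETURN(HloInstruction* ds,
                        MakeDynamicSliceHlo(operand, start_indices_r1, {1, 8}));
    // Per dimension i, the helper emitted
    //   s_i = s32[1] slice(start_indices_r1), slice={[i:i+1]}
    //   r_i = s32[]  reshape(s_i)
    // and `ds` is dynamic-slice(operand, r_0, r_1) with scalar index operands.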
@@ -1410,22 +1410,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     auto operand = dynamic_slice->operand(0);
     auto start_indices = dynamic_slice->operand(1);
     auto result_shape = dynamic_slice->shape();
-    // TODO(b/118437727): Remove all of this nonsense.
-    // We may get an instruction without a parent module. In this case, assume
-    // scalar indices are not allowed.
-    bool allow_scalar_index = false;
-    if (dynamic_slice->GetModule() != nullptr) {
-      allow_scalar_index = dynamic_slice->GetModule()
-                               ->config()
-                               .debug_options()
-                               .xla_allow_scalar_index_dynamic_ops();
-    }
     TF_ASSIGN_OR_RETURN(
         auto inferred_return_shape,
         ShapeInference::InferDynamicSliceShape(
            operand->shape(),
            Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
-            dynamic_slice->dynamic_slice_sizes(), allow_scalar_index));
+            dynamic_slice->dynamic_slice_sizes()));
     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
@@ -1483,20 +1473,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
    auto update = dynamic_update_slice->operand(1);
    auto start_indices = dynamic_update_slice->operand(2);
    auto result_shape = dynamic_update_slice->shape();
-    bool allow_scalar_index = false;
-    if (dynamic_update_slice->GetModule() != nullptr) {
-      allow_scalar_index = dynamic_update_slice->GetModule()
-                               ->config()
-                               .debug_options()
-                               .xla_allow_scalar_index_dynamic_ops();
-    }
     TF_ASSIGN_OR_RETURN(
         auto inferred_return_shape,
         ShapeInference::InferDynamicUpdateSliceShape(
             operand->shape(), update->shape(),
             Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
-                ->index_shapes(),
-            allow_scalar_index));
+                ->index_shapes()));
     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
@@ -504,38 +504,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
 }

 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
-  const DebugOptions& debug_options =
-      dynamic_slice->GetModule()->config().debug_options();
-  const bool allow_scalar_indices =
-      debug_options.xla_allow_scalar_index_dynamic_ops();
-  if (!allow_scalar_indices) {
-    TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2));
-  }
   return CheckShape(
       dynamic_slice,
       ShapeInference::InferDynamicSliceShape(
           dynamic_slice->operand(0)->shape(),
           Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
-          dynamic_slice->dynamic_slice_sizes(), allow_scalar_indices));
+          dynamic_slice->dynamic_slice_sizes()));
 }

 Status ShapeVerifier::HandleDynamicUpdateSlice(
     HloInstruction* dynamic_update_slice) {
-  const DebugOptions& debug_options =
-      dynamic_update_slice->GetModule()->config().debug_options();
-  const bool allow_scalar_indices =
-      debug_options.xla_allow_scalar_index_dynamic_ops();
-  if (!allow_scalar_indices) {
-    TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3));
-  }
   return CheckShape(
       dynamic_update_slice,
       ShapeInference::InferDynamicUpdateSliceShape(
          dynamic_update_slice->operand(0)->shape(),
          dynamic_update_slice->operand(1)->shape(),
           Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
-              ->index_shapes(),
-          allow_scalar_indices));
+              ->index_shapes()));
 }

 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
@@ -450,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
   HloModule SliceWithLayoutChange
   ENTRY SliceWithLayoutChange {
     par0 = f32[4,5]{0,1} parameter(0)
-    par1 = s32[2] parameter(1)
-    ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
+    par1 = s32[] parameter(1)
+    par2 = s32[] parameter(2)
+    ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
       dynamic_slice_sizes={3,4}
   }
  )";
@@ -1,12 +1,12 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
 load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
     "if_static",
 )

+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
 cc_library(
     name = "interpreter_transfer_manager",
     srcs = ["interpreter_transfer_manager.cc"],
@@ -35,6 +35,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:batchnorm_expander",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
         "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
@@ -21,6 +21,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
 #include "tensorflow/compiler/xla/service/hlo_cse.h"
@@ -44,6 +45,7 @@ namespace interpreter {
 Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");

+  pipeline.AddPass<DynamicIndexSplitter>();
   pipeline.AddPass<LayoutAssignment>(
       hlo_module->mutable_entry_computation_layout(),
       LayoutAssignment::InstructionCanChangeLayout);
@@ -960,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
   ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
     par0 = f32[3,4]{1,0} parameter(0)
     par1 = f32[4,5]{0,1} parameter(1)
-    par2 = s32[2] parameter(2)
-    dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4}
+    par2 = s32[] parameter(2)
+    par3 = s32[] parameter(3)
+    dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
     ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
   }
 )";
@@ -982,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
                   m::Parameter(),
                   m::DynamicSlice(
                       m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
-                      m::Parameter(2)))));
+                      m::Parameter(2), m::Parameter(3)))));
 }

 TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
@@ -2118,7 +2118,6 @@ XLA_BINOP_PATTERN(Divide)
 XLA_BINOP_PATTERN(Complex)
 XLA_BINOP_PATTERN(Convolution)
 XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(DynamicSlice)
 XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
 XLA_BINOP_PATTERN(Gather)
 XLA_BINOP_PATTERN(Ge)
@@ -2235,6 +2234,7 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
 XLA_VARIADIC_OP_PATTERN(AfterAll);
 XLA_VARIADIC_OP_PATTERN(Concatenate);
 XLA_VARIADIC_OP_PATTERN(CustomCall);
+XLA_VARIADIC_OP_PATTERN(DynamicSlice);
 XLA_VARIADIC_OP_PATTERN(Map);
 XLA_VARIADIC_OP_PATTERN(Reduce);
 XLA_VARIADIC_OP_PATTERN(Sort);
@@ -2099,15 +2099,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   }

   const Shape& start_indices_shape = start_index_shapes[0];
+  TF_RETURN_IF_ERROR(
+      ExpectArray(start_indices_shape, "start indices of dynamic slice"));
+
   VLOG(2) << StrFormat(
       "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
       ShapeUtil::HumanString(operand_shape),
       ShapeUtil::HumanString(start_indices_shape),
       StrJoin(slice_sizes, ", "));

-  TF_RETURN_IF_ERROR(
-      ExpectArray(start_indices_shape, "start indices of dynamic slice"));
-
   if (start_indices_shape.rank() != 1) {
     return InvalidArgument(
         "Dynamic slice start indices of rank %d must be rank1.",
@@ -2151,7 +2151,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
           "Dynamic slice start indices must be of integral type.");
     }
     for (const Shape& index_shape : start_index_shapes) {
-      if (!ShapeUtil::Equal(first_index_shape, index_shape)) {
+      if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
        return InvalidArgument(
            "Dynamic slice start indices must all have the same shape, got "
            "mismatching indices with shapes %s and %s.",
@@ -2258,7 +2258,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
           "Dynamic update slice start indices must be of integral type.");
     }
     for (const Shape& index_shape : start_index_shapes) {
-      if (!ShapeUtil::Equal(first_index_shape, index_shape)) {
+      if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
         return InvalidArgument(
             "Dynamic update slice start indices must all have the same "
             "shape, got mismatching indices with shapes %s and %s.",
@@ -177,14 +177,14 @@ class ShapeInference {
   // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
   static StatusOr<Shape> InferDynamicSliceShape(
       const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
-      absl::Span<const int64> slice_sizes, bool allow_scalar_indices = false);
+      absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true);

   // Infers the shape produced by a dynamic update slice operation based
   // on the shape of operand and update.
   static StatusOr<Shape> InferDynamicUpdateSliceShape(
       const Shape& operand_shape, const Shape& update_shape,
       absl::Span<const Shape> start_index_shapes,
-      bool allow_scalar_indices = false);
+      bool allow_scalar_indices = true);

   // Infers the shape produced by doing a compile-time-constant indexing into
   // the given input shape. This is essential for operations on tuples, because
@@ -112,6 +112,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,

   if (compare_layouts) {
     if (lhs.layout().format() != rhs.layout().format()) {
+      VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
       return false;
     }
     if (LayoutUtil::IsDenseArray(lhs)) {
@@ -71,6 +71,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:transfer_manager",
@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -274,16 +275,9 @@ bool NeedsInitValue(const HloUse& use) {

 // Generate random values that are constrained to the input_shape minus the
 // output_shape so as not to produce wrapping slices, for instance.
-Literal MakeRandomIndex(absl::Span<const int64> index_space,
-                        std::minstd_rand0* engine) {
-  std::vector<int32> start_indices(index_space.size());
-  if (engine != nullptr) {
-    for (int i = 0; i < index_space.size(); ++i) {
-      std::uniform_int_distribution<int32> generator(0, index_space[i]);
-      start_indices[i] = generator(*engine);
-    }
-  }
-  return LiteralUtil::CreateR1<int32>(start_indices);
+Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
+  std::uniform_int_distribution<int32> generator(0, index_bound);
+  return LiteralUtil::CreateR0<int32>(generator(*engine));
 }

 // Use dataflow analysis on each parameter to see if there are uses that would
@@ -300,8 +294,8 @@ std::vector<HloInstruction*> FindConstrainedUses(
       HloInstruction* instruction = use.instruction;
       const HloOpcode opcode = instruction->opcode();
       const int64 op_num = use.operand_number;
-      if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) ||
-          (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) {
+      if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
+          (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
         constrained_uses.push_back(instruction);
       } else if (opcode == HloOpcode::kFusion) {
         const HloInstruction* const to_analyze =
@@ -336,7 +330,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
 StatusOr<Literal> CreateLiteralForConstrainedUses(
     const absl::Span<HloInstruction* const> constrained_uses,
     const HloInstruction& param, std::minstd_rand0* engine) {
-  std::vector<int64> index_space;
+  int64 index_bound = INT64_MAX;
   bool no_duplicates = false;
   bool needs_constant = false;
   ConstantType constant_type = ConstantType::kUnknown;
@@ -348,19 +342,16 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
         const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
                                        ? use->shape()
                                        : use->operand(1)->shape();
-        const int64 rank = indexed_shape.rank();
-        if (!index_space.empty()) {
-          TF_RET_CHECK(rank == index_space.size());
-          for (int64 i = 0; i < rank; ++i) {
-            index_space[i] = std::min(
-                index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
-                                    ShapeUtil::GetDimension(slice_shape, i));
-          }
-        } else {
-          index_space.resize(rank);
-          for (int64 i = 0; i < rank; ++i) {
-            index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
-                             ShapeUtil::GetDimension(slice_shape, i);
+        const int64 first_index =
+            Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
+        for (int64 operand = first_index; operand < use->operand_count();
+             ++operand) {
+          if (use->operand(operand) == &param) {
+            index_bound = std::min(
+                index_bound,
+                ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
+                    ShapeUtil::GetDimension(slice_shape,
+                                            operand - first_index));
           }
         }
         break;
@@ -388,16 +379,13 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
   }
   int constraint_count = 0;
   constraint_count += no_duplicates ? 1 : 0;
-  constraint_count += !index_space.empty() ? 1 : 0;
+  constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
   constraint_count += needs_constant ? 1 : 0;
   if (constraint_count > 1) {
     return Unimplemented("Conflicting operand generation constraints.");
   }
-  if (!index_space.empty()) {
-    // constrained_uses looks through bitcasts, so param and indexed_space may
-    // not have the same shape. (For example, param might be an R0 while
-    // indexed_space might have size 1.)
-    return MakeRandomIndex(index_space, engine)
+  if (index_bound != INT64_MAX) {
+    return MakeRandomIndex(index_bound, engine)
         .Reshape(param.shape().dimensions());
   } else if (needs_constant) {
     switch (constant_type) {
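With scalar indices, the per-dimension bound for a constrained index parameter is dim(indexed_shape, d) - dim(slice_shape, d), minimized across all constrained uses; the tests below exercise exactly this. As a sanity check, worked numbers for the first test's module: dimension 0 gives min(123 - 1, 3 - 3) = 0, so args[0] must be exactly 0; dimension 1 gives min(4 - 2, 3000 - 2) = 2; dimension 2 gives min(789 - 3, 5 - 2) = 3.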
@@ -79,25 +79,26 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
       R"(HloModule index_space_module

     ENTRY IndexSpace {
-      index_param = s32[3]{0} parameter(0)
-      array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
-      array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
-      dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
-      ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
+      index_param.0 = s32[] parameter(0)
+      index_param.1 = s32[] parameter(1)
+      index_param.2 = s32[] parameter(2)
+      array_param.1 = f32[123,4,789]{0,1,2} parameter(3)
+      array_param.2 = f32[3,3000,5]{0,1,2} parameter(4)
+      dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3}
+      ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2}
     })")
          .ValueOrDie();
   TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
-  ASSERT_EQ(args.size(), 3);
-  const Literal& index_arg = args[0];
+  ASSERT_EQ(args.size(), 5);

-  EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+  EXPECT_EQ(args[0].Get<int32>({}), 0);

-  EXPECT_GE(index_arg.Get<int32>({1}), 0);
-  EXPECT_LE(index_arg.Get<int32>({1}), 2);
+  EXPECT_GE(args[1].Get<int32>({}), 0);
+  EXPECT_LE(args[1].Get<int32>({}), 2);

-  EXPECT_GE(index_arg.Get<int32>({2}), 0);
-  EXPECT_LE(index_arg.Get<int32>({2}), 3);
+  EXPECT_GE(args[2].Get<int32>({}), 0);
+  EXPECT_LE(args[2].Get<int32>({}), 3);
 }

 XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
@@ -105,28 +106,29 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
      R"(HloModule index_space_module

     ENTRY IndexSpace {
-      index_param = s32[3]{0} parameter(0)
-      array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
-      array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
-      update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
-      update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
+      index_param.0 = s32[] parameter(0)
+      index_param.1 = s32[] parameter(1)
+      index_param.2 = s32[] parameter(2)
+      array_param.1 = f32[123,4,789]{0,1,2} parameter(3)
+      array_param.2 = f32[3,3000,5]{0,1,2} parameter(4)
+      update_param.1 = f32[1,2,3]{0,1,2} parameter(5)
+      update_param.2 = f32[3,2,2]{0,1,2} parameter(6)

-      dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
-      ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
+      dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2)
+      ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2)
     })")
          .ValueOrDie();
   TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
-  ASSERT_EQ(args.size(), 5);
-  const Literal& index_arg = args[0];
+  ASSERT_EQ(args.size(), 7);

-  EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+  EXPECT_EQ(args[0].Get<int32>({}), 0);

-  EXPECT_GE(index_arg.Get<int32>({1}), 0);
-  EXPECT_LE(index_arg.Get<int32>({1}), 2);
+  EXPECT_GE(args[1].Get<int32>({}), 0);
+  EXPECT_LE(args[1].Get<int32>({}), 2);

-  EXPECT_GE(index_arg.Get<int32>({2}), 0);
-  EXPECT_LE(index_arg.Get<int32>({2}), 3);
+  EXPECT_GE(args[2].Get<int32>({}), 0);
+  EXPECT_LE(args[2].Get<int32>({}), 3);
 }

 XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {