[XLA] Adapt HLO pipeline to scalar-index DS and DUS

This adapts HLO to work with the scalar DynamicSlice / DynamicUpdateSlice
form. This means it also effectively switches the backends to use the new form.

Also a bunch of test churn: tests that don't run optimizations (and hence
don't trigger the splitter pass) needed to be converted to the scalar form.

PiperOrigin-RevId: 227908542
Michael Kuperstein, 2019-01-04 13:52:30 -08:00 (committed by TensorFlower Gardener)
parent 4614b04cc0
commit 962d8821eb
22 changed files with 161 additions and 162 deletions
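
For orientation, a minimal C++ sketch of the two index forms, using only the
factory calls that appear in the diffs below (builder, operand, and ds_shape
are placeholder names):

    // Old form: a single rank-1 start_indices operand holds all starts.
    HloInstruction* r1_starts = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2, 5})));
    auto* old_ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
        ds_shape, operand, r1_starts, /*slice_sizes=*/{3, 4}));

    // New form: one scalar (R0) start-index operand per sliced dimension.
    std::vector<HloInstruction*> scalar_starts = {
        builder.AddInstruction(
            HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2))),
        builder.AddInstruction(
            HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)))};
    auto* new_ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
        ds_shape, operand, scalar_starts, /*slice_sizes=*/{3, 4}));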

View File

@ -1380,6 +1380,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
// => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
HloDynamicSliceInstruction* dynamic_slice =
lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
: Cast<HloDynamicSliceInstruction>(rhs);
// ctA:
HloInstruction* left_operand =
@ -1397,8 +1400,6 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
dnums, dot->precision_config()));
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
// Position of start:
int index_of_non_zero_start = lhs_is_dynamic_slice
? 1 - lhs_contracting_dimension
@ -1407,23 +1408,19 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
int index_of_zero_start = 1 - index_of_non_zero_start;
// Slice out start and 0 components and reorder if necessary.
auto indices_type = original_start_indices->shape().element_type();
auto indices_type = dynamic_slice->operand(1)->shape().element_type();
Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
HloInstruction* non_zero_start =
computation_->AddInstruction(HloInstruction::CreateSlice(
s_shape, original_start_indices, {index_of_non_zero_start},
{index_of_non_zero_start + 1}, {1}));
dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
HloInstruction* zero_start =
computation_->AddInstruction(HloInstruction::CreateSlice(
s_shape, original_start_indices, {index_of_zero_start},
{index_of_zero_start + 1}, {1}));
HloInstruction* new_start_indices =
lhs_is_dynamic_slice
? computation_->AddInstruction(HloInstruction::CreateConcatenate(
d_shape, {non_zero_start, zero_start}, 0))
: computation_->AddInstruction(HloInstruction::CreateConcatenate(
d_shape, {zero_start, non_zero_start}, 0));
dynamic_slice->mutable_operand(1 + index_of_zero_start);
std::vector<HloInstruction*> new_start_indices;
if (lhs_is_dynamic_slice) {
new_start_indices = {non_zero_start, zero_start};
} else {
new_start_indices = {zero_start, non_zero_start};
}
// Build DynamicSlice(ctA x ctB).
const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
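
Net effect of this hunk, as a hedged sketch; memoized_inst and new_slice_n do
not appear in the excerpt and are assumed names:

    // DS(A, i, j) dot ctB  ==>  DS(A dot ctB, i', j'): the original scalar
    // start operands are reused directly, in possibly swapped order, instead
    // of being sliced out of an R1 vector and re-concatenated.
    auto* rewritten = computation_->AddInstruction(
        HloInstruction::CreateDynamicSlice(dot->shape(), memoized_inst,
                                           new_start_indices,
                                           {new_slice_m, new_slice_n}));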

View File

@ -4535,14 +4535,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
int32 start_col = (spec.lcd == 0) ? spec.s : 0;
const auto start_indices =
std::vector<HloInstruction*> start_indices = {
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>({start_row, start_col})));
LiteralUtil::CreateR0<int32>(start_row))),
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(start_col)))};
int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ds_shape, lhs, start_indices, {slice_row_size, slice_col_size}));
ds_shape, lhs, start_indices, slice_sizes));
int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
@ -4575,7 +4578,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
} else {
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
m::Concatenate())));
m::Constant(), m::Constant())));
}
}
@ -4613,14 +4616,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
int32 start_col = (spec.rcd == 0) ? spec.s : 0;
const auto start_indices =
std::vector<HloInstruction*> start_indices = {
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>({start_row, start_col})));
LiteralUtil::CreateR0<int32>(start_row))),
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(start_col)))};
int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ds_shape, rhs, start_indices, {slice_row_size, slice_col_size}));
ds_shape, rhs, start_indices, slice_sizes));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
@ -4645,7 +4651,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
} else {
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
m::Concatenate())));
m::Constant(), m::Constant())));
}
}

View File

@ -2486,9 +2486,9 @@ while_body {
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[3]{0} constant({0, 0, 0})
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

View File

@ -1,6 +1,14 @@
# Description:
# LLVM-based CPU backend for XLA.
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
load(
"//third_party/mkl:build_defs.bzl",
"mkl_deps",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load(":build_defs.bzl", "runtime_copts")
licenses(["notice"]) # Apache 2.0
package(
@ -14,15 +22,6 @@ package_group(
],
)
load(":build_defs.bzl", "runtime_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
load(
"//third_party/mkl:build_defs.bzl",
"mkl_deps",
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
@ -114,6 +113,7 @@ cc_library(
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",

View File

@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -244,6 +245,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloPassPipeline pipeline("HLO passes through layout assignment");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
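
The same two lines recur for the GPU and interpreter backends below. A minimal
sketch of the pattern, assuming only the pass's HloModulePass interface:

    // DynamicIndexSplitter rewrites R1-index DS/DUS into the scalar-index
    // form up front, so later passes and backends see only scalar indices.
    HloPassPipeline pipeline("scalarize-dynamic-indices");
    pipeline.AddPass<DynamicIndexSplitter>();
    TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));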

View File

@ -3,6 +3,11 @@
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
licenses(["notice"]) # Apache 2.0
@ -24,12 +29,6 @@ filegroup(
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
xla_proto_library(
name = "backend_configs",
srcs = ["backend_configs.proto"],
@ -700,6 +699,7 @@ cc_library(
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",

View File

@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
p.1 = s32[1]{0} parameter(1)
p.2 = f16[1,96,1024]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
fusion.2 {
@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
p.2 = f16[1,96,1024]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
ENTRY entry {

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
@ -152,6 +153,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@ -105,12 +106,26 @@ StatusOr<HloInstruction*> MakeDynamicSliceHlo(
absl::Span<const int64> slice_sizes) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, start_indices->parent());
int64 rank = start_indices->shape().dimensions(0);
std::vector<HloInstruction*> scalar_start_indices;
for (int i = 0; i < rank; ++i) {
// TODO(b/118437727): Update callers to provide scalars directly.
auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
start_indices, {i}, {i + 1}, {1}));
scalar_start_indices.push_back(
computation->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
slice)));
}
std::vector<Shape> scalar_start_indices_shapes(
rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
TF_ASSIGN_OR_RETURN(
Shape dynamic_slice_shape,
ShapeInference::InferDynamicSliceShape(
operand->shape(), {start_indices->shape()}, slice_sizes));
operand->shape(), scalar_start_indices_shapes, slice_sizes));
return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
dynamic_slice_shape, operand, start_indices, slice_sizes));
dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
}
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
@ -119,12 +134,26 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
HloComputation* computation = operand->parent();
CHECK_EQ(computation, update->parent());
CHECK_EQ(computation, start_indices->parent());
int64 rank = start_indices->shape().dimensions(0);
std::vector<HloInstruction*> scalar_start_indices;
for (int i = 0; i < rank; ++i) {
// TODO(b/118437727): Update callers to provide scalars directly.
auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
start_indices, {i}, {i + 1}, {1}));
scalar_start_indices.push_back(
computation->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
slice)));
}
std::vector<Shape> scalar_start_indices_shapes(
rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
TF_ASSIGN_OR_RETURN(
Shape dynamic_update_slice_shape,
ShapeInference::InferDynamicUpdateSliceShape(
operand->shape(), update->shape(), {start_indices->shape()}));
operand->shape(), update->shape(), scalar_start_indices_shapes));
return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
dynamic_update_slice_shape, operand, update, start_indices));
dynamic_update_slice_shape, operand, update, scalar_start_indices));
}
StatusOr<HloInstruction*> MakeBroadcastHlo(
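
Both helpers still accept a single R1 start_indices and split it internally
(per the TODO above), so existing callers compile unchanged. A usage sketch
with assumed operand shapes:

    // The helper emits a Slice+Reshape per dimension and then builds the
    // scalar-index dynamic-slice shown above.
    TF_ASSIGN_OR_RETURN(HloInstruction * ds,
                        MakeDynamicSliceHlo(operand, r1_start_indices,
                                            /*slice_sizes=*/{1, 128}));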

View File

@ -1410,22 +1410,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto operand = dynamic_slice->operand(0);
auto start_indices = dynamic_slice->operand(1);
auto result_shape = dynamic_slice->shape();
// TODO(b/118437727): Remove all of this nonsense.
// We may get an instruction without a parent module. In this case, assume
// scalar indices are not allowed.
bool allow_scalar_index = false;
if (dynamic_slice->GetModule() != nullptr) {
allow_scalar_index = dynamic_slice->GetModule()
->config()
.debug_options()
.xla_allow_scalar_index_dynamic_ops();
}
TF_ASSIGN_OR_RETURN(
auto inferred_return_shape,
ShapeInference::InferDynamicSliceShape(
operand->shape(),
Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
dynamic_slice->dynamic_slice_sizes(), allow_scalar_index));
dynamic_slice->dynamic_slice_sizes()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
@ -1483,20 +1473,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto update = dynamic_update_slice->operand(1);
auto start_indices = dynamic_update_slice->operand(2);
auto result_shape = dynamic_update_slice->shape();
bool allow_scalar_index = false;
if (dynamic_update_slice->GetModule() != nullptr) {
allow_scalar_index = dynamic_update_slice->GetModule()
->config()
.debug_options()
.xla_allow_scalar_index_dynamic_ops();
}
TF_ASSIGN_OR_RETURN(
auto inferred_return_shape,
ShapeInference::InferDynamicUpdateSliceShape(
operand->shape(), update->shape(),
Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
->index_shapes(),
allow_scalar_index));
->index_shapes()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "

View File

@ -504,38 +504,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
}
Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
const DebugOptions& debug_options =
dynamic_slice->GetModule()->config().debug_options();
const bool allow_scalar_indices =
debug_options.xla_allow_scalar_index_dynamic_ops();
if (!allow_scalar_indices) {
TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2));
}
return CheckShape(
dynamic_slice,
ShapeInference::InferDynamicSliceShape(
dynamic_slice->operand(0)->shape(),
Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
dynamic_slice->dynamic_slice_sizes(), allow_scalar_indices));
dynamic_slice->dynamic_slice_sizes()));
}
Status ShapeVerifier::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) {
const DebugOptions& debug_options =
dynamic_update_slice->GetModule()->config().debug_options();
const bool allow_scalar_indices =
debug_options.xla_allow_scalar_index_dynamic_ops();
if (!allow_scalar_indices) {
TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3));
}
return CheckShape(
dynamic_update_slice,
ShapeInference::InferDynamicUpdateSliceShape(
dynamic_update_slice->operand(0)->shape(),
dynamic_update_slice->operand(1)->shape(),
Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
->index_shapes(),
allow_scalar_indices));
->index_shapes()));
}
Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {

View File

@ -450,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
par0 = f32[4,5]{0,1} parameter(0)
par1 = s32[2] parameter(1)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
par1 = s32[] parameter(1)
par2 = s32[] parameter(2)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
dynamic_slice_sizes={3,4}
}
)";

View File

@ -1,12 +1,12 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
cc_library(
name = "interpreter_transfer_manager",
srcs = ["interpreter_transfer_manager.cc"],
@ -35,6 +35,7 @@ cc_library(
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
@ -44,6 +45,7 @@ namespace interpreter {
Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<LayoutAssignment>(
hlo_module->mutable_entry_computation_layout(),
LayoutAssignment::InstructionCanChangeLayout);

View File

@ -960,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[4,5]{0,1} parameter(1)
par2 = s32[2] parameter(2)
dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4}
par2 = s32[] parameter(2)
par3 = s32[] parameter(3)
dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
}
)";
@ -982,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
m::Parameter(),
m::DynamicSlice(
m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
m::Parameter(2)))));
m::Parameter(2), m::Parameter(3)))));
}
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {

View File

@ -2118,7 +2118,6 @@ XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Convolution)
XLA_BINOP_PATTERN(Dot)
XLA_BINOP_PATTERN(DynamicSlice)
XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
@ -2235,6 +2234,7 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
XLA_VARIADIC_OP_PATTERN(AfterAll);
XLA_VARIADIC_OP_PATTERN(Concatenate);
XLA_VARIADIC_OP_PATTERN(CustomCall);
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
XLA_VARIADIC_OP_PATTERN(Map)
XLA_VARIADIC_OP_PATTERN(Reduce);
XLA_VARIADIC_OP_PATTERN(Sort);
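
With DynamicSlice now a variadic pattern, match sites spell out each index
operand. A small sketch:

    // Matches a dynamic-slice of a parameter with two start-index operands.
    HloInstruction* start0;
    HloInstruction* start1;
    bool matched = Match(root, m::DynamicSlice(m::Parameter(), m::Op(&start0),
                                               m::Op(&start1)));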

View File

@ -2099,15 +2099,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
const Shape& start_indices_shape = start_index_shapes[0];
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
VLOG(2) << StrFormat(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
ShapeUtil::HumanString(operand_shape),
ShapeUtil::HumanString(start_indices_shape),
StrJoin(slice_sizes, ", "));
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
if (start_indices_shape.rank() != 1) {
return InvalidArgument(
"Dynamic slice start indices of rank %d must be rank1.",
@ -2151,7 +2151,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"Dynamic slice start indices must be of integral type.");
}
for (const Shape& index_shape : start_index_shapes) {
if (!ShapeUtil::Equal(first_index_shape, index_shape)) {
if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
return InvalidArgument(
"Dynamic slice start indices must all have the same shape, got "
"mismatching indices with shapes %s and %s.",
@ -2258,7 +2258,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"Dynamic update slice start indices must be of integral type.");
}
for (const Shape& index_shape : start_index_shapes) {
if (!ShapeUtil::Equal(first_index_shape, index_shape)) {
if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
return InvalidArgument(
"Dynamic update slice start indices must all have the same "
"shape, got mismatching indices with shapes %s and %s.",

View File

@ -177,14 +177,14 @@ class ShapeInference {
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
static StatusOr<Shape> InferDynamicSliceShape(
const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
absl::Span<const int64> slice_sizes, bool allow_scalar_indices = false);
absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true);
// Infers the shape produced by a dynamic update slice operation based
// on the shape of operand and update.
static StatusOr<Shape> InferDynamicUpdateSliceShape(
const Shape& operand_shape, const Shape& update_shape,
absl::Span<const Shape> start_index_shapes,
bool allow_scalar_indices = false);
bool allow_scalar_indices = true);
// Infers the shape produced by doing a compile-time-constant indexing into
// the given input shape. This is essential for operations on tuples, because
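
With the default flipped to true, callers get scalar-index inference without
touching the flag. A usage sketch with assumed shapes:

    // Two scalar S32 start indices for a rank-2 operand.
    std::vector<Shape> starts(2, ShapeUtil::MakeShape(S32, {}));
    TF_ASSIGN_OR_RETURN(Shape inferred,
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, starts, /*slice_sizes=*/{3, 4}));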

View File

@ -112,6 +112,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
if (compare_layouts) {
if (lhs.layout().format() != rhs.layout().format()) {
VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
return false;
}
if (LayoutUtil::IsDenseArray(lhs)) {

View File

@ -71,6 +71,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -274,16 +275,9 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
Literal MakeRandomIndex(absl::Span<const int64> index_space,
std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
return LiteralUtil::CreateR1<int32>(start_indices);
Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
std::uniform_int_distribution<int32> generator(0, index_bound);
return LiteralUtil::CreateR0<int32>(generator(*engine));
}
// Use dataflow analysis on each parameter to see if there are uses that would
@ -300,8 +294,8 @@ std::vector<HloInstruction*> FindConstrainedUses(
HloInstruction* instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) ||
(opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) {
if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
(opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
constrained_uses.push_back(instruction);
} else if (opcode == HloOpcode::kFusion) {
const HloInstruction* const to_analyze =
@ -336,7 +330,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<Literal> CreateLiteralForConstrainedUses(
const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
int64 index_bound = INT64_MAX;
bool no_duplicates = false;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
@ -348,19 +342,16 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
? use->shape()
: use->operand(1)->shape();
const int64 rank = indexed_shape.rank();
if (!index_space.empty()) {
TF_RET_CHECK(rank == index_space.size());
for (int64 i = 0; i < rank; ++i) {
index_space[i] = std::min(
index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i));
}
} else {
index_space.resize(rank);
for (int64 i = 0; i < rank; ++i) {
index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
const int64 first_index =
Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
for (int64 operand = first_index; operand < use->operand_count();
++operand) {
if (use->operand(operand) == &param) {
index_bound = std::min(
index_bound,
ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
ShapeUtil::GetDimension(slice_shape,
operand - first_index));
}
}
break;
@ -388,16 +379,13 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
}
int constraint_count = 0;
constraint_count += no_duplicates ? 1 : 0;
constraint_count += !index_space.empty() ? 1 : 0;
constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
constraint_count += needs_constant ? 1 : 0;
if (constraint_count > 1) {
return Unimplemented("Conflicting operand generation constraints.");
}
if (!index_space.empty()) {
// constrained_uses looks through bitcasts, so param and indexed_space may
// not have the same shape. (For example, param might be an R0 while
// indexed_space might have size 1.)
return MakeRandomIndex(index_space, engine)
if (index_bound != INT64_MAX) {
return MakeRandomIndex(index_bound, engine)
.Reshape(param.shape().dimensions());
} else if (needs_constant) {
switch (constant_type) {
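
Index generation is now per scalar parameter: each index operand's bound is
the minimum over its DS/DUS uses of (operand dim - slice dim). A hedged sketch
of the new helper's contract:

    // One R0 literal per index parameter, uniform in [0, index_bound].
    std::minstd_rand0 engine;
    Literal index = MakeRandomIndex(/*index_bound=*/5, &engine);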

View File

@ -79,25 +79,26 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
index_param.0 = s32[] parameter(0)
index_param.1 = s32[] parameter(1)
index_param.2 = s32[] parameter(2)
array_param.1 = f32[123,4,789]{0,1,2} parameter(3)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(4)
dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3}
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
const Literal& index_arg = args[0];
ASSERT_EQ(args.size(), 5);
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_EQ(args[0].Get<int32>({}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(args[1].Get<int32>({}), 0);
EXPECT_LE(args[1].Get<int32>({}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
EXPECT_GE(args[2].Get<int32>({}), 0);
EXPECT_LE(args[2].Get<int32>({}), 3);
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
@ -105,28 +106,29 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
index_param.0 = s32[] parameter(0)
index_param.1 = s32[] parameter(1)
index_param.2 = s32[] parameter(2)
array_param.1 = f32[123,4,789]{0,1,2} parameter(3)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(4)
update_param.1 = f32[1,2,3]{0,1,2} parameter(5)
update_param.2 = f32[3,2,2]{0,1,2} parameter(6)
dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2)
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2)
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
const Literal& index_arg = args[0];
ASSERT_EQ(args.size(), 7);
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_EQ(args[0].Get<int32>({}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(args[1].Get<int32>({}), 0);
EXPECT_LE(args[1].Get<int32>({}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
EXPECT_GE(args[2].Get<int32>({}), 0);
EXPECT_LE(args[2].Get<int32>({}), 3);
}
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {