[XLA] Roll forward the rolled-back CL after fixing mistakes.
*** Reason for rollback of the rolled-back CL *** Fix the mistake in the rolled-back CL by changing a {0} subscript to {} in the CopyElementFrom invocation. PiperOrigin-RevId: 340492736 Change-Id: Idd4c8ab143bb10c6c9bbe876d7d4bb2747599553
This commit is contained in:
parent
cbc7f31b7d
commit
d7a91161c0
@ -2486,6 +2486,23 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
  // A ReduceWindow with an array (non-tuple) result takes the default path.
  if (!hlo->shape().IsTuple()) {
    return DefaultAction(hlo);
  }

  // Variadic (tuple-shaped) ReduceWindow is forwarded to a typed visitor,
  // selected by the element type of the first tuple element of the result
  // shape. The variadic path inside the typed visitor is written so that it
  // does not rely on the visitor's template parameter, so the particular
  // instantiation picked here is not significant. The implementation stays in
  // the typed visitor (rather than moving here) because it reuses
  // IterateThroughWindow, which is defined and only available inside the typed
  // visitor.
  const auto first_element_type = hlo->shape().tuple_shapes(0).element_type();
  return hlo->Visit(typed_visitors_[first_element_type].get());
}
|
||||
|
||||
Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
|
||||
if (!custom_call_handler_) {
|
||||
// No handler is registered; this means custom-calls are not allowed.
|
||||
|
@ -260,6 +260,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleReduce(HloInstruction* reduce) override;
|
||||
|
||||
Status HandleReduceWindow(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleCustomCall(HloInstruction* custom_call) override;
|
||||
|
||||
// Unsupported HLOs, note some of them (such as BatchNorm*) are typically
|
||||
|
@ -2861,6 +2861,109 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
|
||||
}
|
||||
|
||||
// Variadic (two-operand) ReduceWindow taking the minimum over two identical
// F32 inputs of five elements; judging by the test name, the window has size 3
// and stride 2. Both tuple outputs are expected to be {100, 1}.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
  HloComputation::Builder builder("main");
  HloInstruction* operand_a =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* operand_b =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder bcompute("ComputeFunction");
  const Shape scalar_a = ShapeUtil::MakeShape(F32, {});
  const Shape scalar_b = ShapeUtil::MakeShape(F32, {});
  HloInstruction* x0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_a, "x0"));
  HloInstruction* x1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_b, "x1"));
  HloInstruction* y0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_a, "y0"));
  HloInstruction* y1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(3, scalar_b, "y1"));
  HloInstruction* min_a = bcompute.AddInstruction(
      HloInstruction::CreateBinary(scalar_a, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min_b = bcompute.AddInstruction(
      HloInstruction::CreateBinary(scalar_b, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {min_a, min_b};
  // The tuple is added last to the reducer builder.
  bcompute.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  HloComputation* reducer = m_->AddEmbeddedComputation(bcompute.Build());

  std::vector<HloInstruction*> operands = {operand_a, operand_b};
  HloInstruction* init_a = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* init_b = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  std::vector<HloInstruction*> inits = {init_a, init_b};

  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window_config,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&operand_a->shape(),
                                              &operand_b->shape()};
  std::vector<const Shape*> init_value_shapes = {&init_a->shape(),
                                                 &init_b->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window_config,
                              reducer->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, inits, window_config, reducer));

  auto want_each = LiteralUtil::CreateR1<float>({100, 1});
  auto want = LiteralUtil::MakeTuple({&want_each, &want_each});

  m_->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
|
||||
|
||||
// Variadic ReduceWindow taking the minimum over two inputs with different
// element types (F32 and S32), five elements each; judging by the test name,
// the window has size 3 and stride 2. Expected outputs are {100, 1} for the
// float operand and {15, 12} for the integer operand.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
  HloComputation::Builder builder("main");
  HloInstruction* float_operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* int_operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)), where the
  // "0" values are F32 scalars and the "1" values are S32 scalars.
  HloComputation::Builder bcompute("ComputeFunction");
  const Shape float_scalar = ShapeUtil::MakeShape(F32, {});
  const Shape int_scalar = ShapeUtil::MakeShape(S32, {});
  HloInstruction* x0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(0, float_scalar, "x0"));
  HloInstruction* x1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(1, int_scalar, "x1"));
  HloInstruction* y0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(2, float_scalar, "y0"));
  HloInstruction* y1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(3, int_scalar, "y1"));
  HloInstruction* float_min = bcompute.AddInstruction(
      HloInstruction::CreateBinary(float_scalar, HloOpcode::kMinimum, x0, y0));
  HloInstruction* int_min = bcompute.AddInstruction(
      HloInstruction::CreateBinary(int_scalar, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {float_min, int_min};
  // The tuple is added last to the reducer builder.
  bcompute.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  HloComputation* reducer = m_->AddEmbeddedComputation(bcompute.Build());

  std::vector<HloInstruction*> operands = {float_operand, int_operand};
  HloInstruction* float_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* int_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
  std::vector<HloInstruction*> inits = {float_init, int_init};

  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window_config,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&float_operand->shape(),
                                              &int_operand->shape()};
  std::vector<const Shape*> init_value_shapes = {&float_init->shape(),
                                                 &int_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window_config,
                              reducer->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, inits, window_config, reducer));

  auto want_float = LiteralUtil::CreateR1<float>({100, 1});
  auto want_int = LiteralUtil::CreateR1<int>({15, 12});
  auto want = LiteralUtil::MakeTuple({&want_float, &want_int});

  m_->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
|
||||
|
||||
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
|
||||
HloComputation::Builder b(TestName());
|
||||
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
@ -664,6 +665,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
|
||||
nullptr>
|
||||
Status HandleMinimum(HloInstruction* minimum) {
|
||||
VLOG(2) << "Evaluating minimum\n";
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
|
||||
ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
|
||||
ElementwiseT rhs_el) {
|
||||
@ -1932,18 +1934,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
Status HandleReduceWindow(HloInstruction* reduce_window) override {
|
||||
if (reduce_window->shape().IsTuple()) {
|
||||
return Status(tensorflow::error::UNIMPLEMENTED,
|
||||
"Variadic reduce window op is not yet fully supported.");
|
||||
}
|
||||
auto operand = reduce_window->operand(0);
|
||||
auto* reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
|
||||
const Window& window = reduce_window->window();
|
||||
HloComputation* function = reduce_window->to_apply();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto inferred_return_shape,
|
||||
ShapeInference::InferReduceWindowShape(
|
||||
/*operand_shape=*/reduce_window->operand(0)->shape(),
|
||||
/*init_value=*/reduce_window->operand(1)->shape(), window,
|
||||
reduce_window_instr->input_array_shapes(),
|
||||
reduce_window_instr->init_value_shapes(), window,
|
||||
/*to_apply_shape=*/function->ComputeProgramShape()));
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
|
||||
@ -1952,62 +1950,101 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
<< " but is inferred to be: "
|
||||
<< ShapeUtil::HumanStringWithLayout(inferred_return_shape);
|
||||
|
||||
const Literal& operand_literal =
|
||||
parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
|
||||
VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
|
||||
const Literal& init_literal =
|
||||
parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
|
||||
VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
|
||||
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
|
||||
auto init_scalar = init_literal.Get<ReturnT>({});
|
||||
|
||||
absl::InlinedVector<const Literal*, 2> input_literal_vec, init_literal_vec;
|
||||
auto input_arrays = reduce_window_instr->input_arrays();
|
||||
auto init_values = reduce_window_instr->init_values();
|
||||
int64 num_args = input_arrays.size();
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
const Literal& input_literal =
|
||||
parent_->GetEvaluatedLiteralFor(input_arrays[i]);
|
||||
VLOG(3) << "HandleReduceWindow arg_literal: " << input_literal.ToString();
|
||||
input_literal_vec.push_back(&input_literal);
|
||||
const Literal& init_literal =
|
||||
parent_->GetEvaluatedLiteralFor(init_values[i]);
|
||||
VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
|
||||
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
|
||||
init_literal_vec.push_back(&init_literal);
|
||||
}
|
||||
// Creates a Shape object from window, for iteration below.
|
||||
std::vector<int64> window_dimension_sizes;
|
||||
absl::InlinedVector<int64, 2> window_dimension_sizes;
|
||||
for (const auto& window_dimension : window.dimensions()) {
|
||||
window_dimension_sizes.push_back(window_dimension.size());
|
||||
}
|
||||
const Shape window_shape = ShapeUtil::MakeShape(
|
||||
operand->shape().element_type(), window_dimension_sizes);
|
||||
|
||||
DimensionVector window_index(window.dimensions_size());
|
||||
DimensionVector operand_index(operand_literal.shape().rank());
|
||||
input_arrays[0]->shape().element_type(), window_dimension_sizes);
|
||||
|
||||
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
|
||||
Literal result(reduce_window->shape());
|
||||
// For each resulting dimension, calculate and assign computed value.
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
|
||||
ReturnT result_val = init_scalar;
|
||||
|
||||
std::fill(window_index.begin(), window_index.end(), 0);
|
||||
std::fill(operand_index.begin(), operand_index.end(), 0);
|
||||
|
||||
IterateThroughWindow(
|
||||
window_shape, window, operand_literal.shape(), output_index,
|
||||
[&](const std::vector<int64>& operand_index) {
|
||||
auto curr_val = operand_literal.Get<ReturnT>(operand_index);
|
||||
|
||||
// Evaluate computation with specified literal operands.
|
||||
const auto curr_val_literal =
|
||||
LiteralUtil::CreateR0<ReturnT>(curr_val);
|
||||
const auto result_val_literal =
|
||||
LiteralUtil::CreateR0<ReturnT>(result_val);
|
||||
Literal computed_result =
|
||||
embedded_evaluator
|
||||
.Evaluate(*function,
|
||||
{&result_val_literal, &curr_val_literal})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
// Clear visit states so that the we can use the evaluate again
|
||||
// on the same computation.
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
|
||||
result_val = computed_result.Get<ReturnT>({});
|
||||
});
|
||||
|
||||
return result_val;
|
||||
}));
|
||||
|
||||
auto evaluate_impl =
|
||||
[&](absl::Span<const int64> output_index) -> std::vector<Literal> {
|
||||
std::vector<Literal> computed_result;
|
||||
computed_result.reserve(init_literal_vec.size());
|
||||
for (const auto* init : init_literal_vec) {
|
||||
computed_result.push_back(init->Clone());
|
||||
}
|
||||
IterateThroughWindow(
|
||||
window_shape, window, input_literal_vec[0]->shape(), output_index,
|
||||
[&](absl::Span<const int64> operand_index) -> void {
|
||||
absl::InlinedVector<const Literal*, 2> args;
|
||||
for (auto& curr_result_val : computed_result) {
|
||||
VLOG(2) << "Pushing:" << curr_result_val.ToString() << "\n";
|
||||
args.push_back(&curr_result_val);
|
||||
}
|
||||
absl::InlinedVector<Literal, 2> curr_val_literal_vec(
|
||||
input_literal_vec.size());
|
||||
for (const auto* input_literal : input_literal_vec) {
|
||||
// Evaluate computation with specified literal operands.
|
||||
curr_val_literal_vec.push_back(Literal(ShapeUtil::MakeShape(
|
||||
input_literal->shape().element_type(), {})));
|
||||
TF_CHECK_OK(curr_val_literal_vec.back().CopyElementFrom(
|
||||
*input_literal, operand_index, {}));
|
||||
VLOG(2) << "Pushing:" << curr_val_literal_vec.back().ToString()
|
||||
<< "\n";
|
||||
args.push_back(&curr_val_literal_vec.back());
|
||||
}
|
||||
computed_result[0] = embedded_evaluator.Evaluate(*function, args)
|
||||
.ConsumeValueOrDie();
|
||||
VLOG(2) << "Computed result:" << computed_result[0].ToString()
|
||||
<< "\n";
|
||||
// Clear visit states so that we can use the evaluator again
|
||||
// on the same computation.
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
if (inferred_return_shape.IsTuple()) {
|
||||
computed_result = computed_result[0].DecomposeTuple();
|
||||
}
|
||||
});
|
||||
VLOG(2) << "Final result size:" << computed_result.size() << "\n";
|
||||
for (const auto& res : computed_result) {
|
||||
VLOG(2) << res.ToString() << "\n";
|
||||
}
|
||||
return computed_result;
|
||||
};
|
||||
Literal result(inferred_return_shape);
|
||||
if (inferred_return_shape.IsTuple()) {
|
||||
absl::InlinedVector<Literal, 1> results(num_args);
|
||||
for (int64 i = 0; i < num_args; ++i) {
|
||||
results[i] = Literal(inferred_return_shape.tuple_shapes(i));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
|
||||
inferred_return_shape.tuple_shapes(0),
|
||||
[&](absl::Span<const int64> output_index) -> StatusOr<bool> {
|
||||
std::vector<Literal> computed_result_vec =
|
||||
evaluate_impl(output_index);
|
||||
for (int i = 0; i < computed_result_vec.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(results[i].CopyElementFrom(
|
||||
computed_result_vec[i], {}, output_index));
|
||||
}
|
||||
return true;
|
||||
}));
|
||||
result = Literal::MoveIntoTuple(absl::MakeSpan(results));
|
||||
VLOG(2) << "Final result is:" << result.ToString() << "\n";
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
|
||||
return evaluate_impl(output_index)[0].template Get<ReturnT>({});
|
||||
}));
|
||||
}
|
||||
VLOG(2) << "Final result is:" << result.ToString() << "\n";
|
||||
parent_->evaluated_[reduce_window] = std::move(result);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -881,12 +881,16 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) {
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
|
||||
VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
|
||||
auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
|
||||
auto input_shapes = reduce_window_instr->input_array_shapes();
|
||||
VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
|
||||
auto init_shapes = reduce_window_instr->init_value_shapes();
|
||||
VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
|
||||
TF_RETURN_IF_ERROR(CheckShape(
|
||||
reduce_window,
|
||||
ShapeInference::InferReduceWindowShape(
|
||||
reduce_window->operand(0)->shape(),
|
||||
reduce_window->operand(1)->shape(), reduce_window->window(),
|
||||
reduce_window->to_apply()->ComputeProgramShape())));
|
||||
reduce_window, ShapeInference::InferReduceWindowShape(
|
||||
input_shapes, init_shapes, reduce_window->window(),
|
||||
reduce_window->to_apply()->ComputeProgramShape())));
|
||||
|
||||
return allow_mixed_precision_
|
||||
? Status::OK()
|
||||
|
@ -2168,8 +2168,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
|
||||
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
|
||||
const Window& window, const ProgramShape& to_apply_shape) {
|
||||
absl::Span<const Shape* const> operands,
|
||||
absl::Span<const Shape* const> init_values, const Window& window,
|
||||
const ProgramShape& to_apply_shape) {
|
||||
auto number_of_input = operands.size();
|
||||
// Check that all of the reduced tensors have the same dimensions. The element
|
||||
// types may be different.
|
||||
|
@ -168,8 +168,9 @@ class ShapeInference {
|
||||
const Shape& init_value,
|
||||
const Window& window);
|
||||
static StatusOr<Shape> InferReduceWindowShape(
|
||||
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
|
||||
const Window& window, const ProgramShape& to_apply_shape);
|
||||
absl::Span<const Shape* const> operands,
|
||||
absl::Span<const Shape* const> init_values, const Window& window,
|
||||
const ProgramShape& to_apply_shape);
|
||||
|
||||
static StatusOr<Shape> InferReduceWindowShape(
|
||||
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
|
||||
|
Loading…
Reference in New Issue
Block a user