[XLA] Rollforward of rolled-back CL after fixing mistakes.

*** Reason for rollforward of the rolled-back CL *** Fix the mistake in the rolled-back CL by changing a {0} subscript to {} in the CopyElementFrom invocation.

PiperOrigin-RevId: 340492736
Change-Id: Idd4c8ab143bb10c6c9bbe876d7d4bb2747599553
parent cbc7f31b7d
commit d7a91161c0
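
The fix itself is a one-liner. xla::Literal::CopyElementFrom(src, src_index, dest_index) copies a single element between literals, with each index given as a multi-dimensional coordinate. A rank-0 (scalar) destination has no dimensions, so its lone element is addressed by the empty index {}, not by {0}. A minimal sketch of the corrected call, with illustrative values that are not from the commit:

    // src is rank-1: its elements are addressed by a one-entry index.
    Literal src = LiteralUtil::CreateR1<float>({10.f, 20.f, 30.f});
    // dst is rank-0 (a scalar): its single element is addressed by {}.
    Literal dst(ShapeUtil::MakeShape(F32, {}));
    // Copy src[1] into the scalar. Writing {0} as the destination index,
    // as the rolled-back CL did, names a dimension the scalar doesn't have.
    TF_CHECK_OK(dst.CopyElementFrom(src, /*src_index=*/{1}, /*dest_index=*/{}));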
tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -2486,6 +2486,23 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) {
   return Status::OK();
 }
 
+Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
+  // Here we delegate the handling to the typed visitor class, instantiated by
+  // using the type of the first input of ReduceWindow. The support for the
+  // variadic case inside the typed_visitor is made to not use the template
+  // parameter, so it doesn't really matter which type is used to instantiate
+  // it here. We choose not to move the implementation of HandleReduceWindow
+  // from the typed visitor to here because we need to reuse the
+  // IterateThroughWindow method, which is defined and only available inside
+  // the typed visitor.
+  if (hlo->shape().IsTuple()) {
+    return hlo->Visit(
+        typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
+  } else {
+    return DefaultAction(hlo);
+  }
+}
+
 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
   if (!custom_call_handler_) {
     // No handler is registered; this means custom-calls are not allowed.
tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -260,6 +260,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 
   Status HandleReduce(HloInstruction* reduce) override;
 
+  Status HandleReduceWindow(HloInstruction* hlo) override;
+
   Status HandleCustomCall(HloInstruction* custom_call) override;
 
   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -2861,6 +2861,109 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
 }
 
+TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
+  HloComputation::Builder builder("main");
+  auto input1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
+  auto input2 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
+  HloComputation::Builder bcompute("ComputeFunction");
+  auto shape1 = ShapeUtil::MakeShape(F32, {});
+  auto shape2 = ShapeUtil::MakeShape(F32, {});
+  auto p2 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(0, shape1, "x0"));
+  auto p3 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(1, shape2, "x1"));
+  auto p4 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(2, shape1, "y0"));
+  auto p5 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(3, shape2, "y1"));
+  std::vector<HloInstruction*> compute_vec = {
+      bcompute.AddInstruction(
+          HloInstruction::CreateBinary(shape1, HloOpcode::kMinimum, p2, p4)),
+      bcompute.AddInstruction(
+          HloInstruction::CreateBinary(shape2, HloOpcode::kMinimum, p3, p5))};
+  bcompute.AddInstruction(HloInstruction::CreateTuple(compute_vec));
+  auto compute_tuple = m_->AddEmbeddedComputation(bcompute.Build());
+  std::vector<HloInstruction*> input_vec = {input1, input2};
+  auto init1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
+  auto init2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
+  std::vector<HloInstruction*> init_vec = {init1, init2};
+  auto padding = std::pair<int64, int64>(0, 0);
+  TF_ASSERT_OK_AND_ASSIGN(auto window,
+                          ShapeInference::InferWindowFromDimensions(
+                              {3}, {2}, absl::MakeSpan(&padding, 1),
+                              /*lhs_dilation=*/{},
+                              /*rhs_dilation=*/{}));
+  std::vector<const Shape*> input_shapes = {&input1->shape(), &input2->shape()};
+  std::vector<const Shape*> init_shapes = {&init1->shape(), &init2->shape()};
+  TF_ASSERT_OK_AND_ASSIGN(Shape shape,
+                          ShapeInference::InferReduceWindowShape(
+                              input_shapes, init_shapes, window,
+                              compute_tuple->ComputeProgramShape()));
+  builder.AddInstruction(HloInstruction::CreateReduceWindow(
+      shape, input_vec, init_vec, window, compute_tuple));
+  auto r1 = LiteralUtil::CreateR1<float>({100, 1});
+  auto expected = LiteralUtil::MakeTuple({&r1, &r1});
+  m_->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
+  HloComputation::Builder builder("main");
+  auto input1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
+  auto input2 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));
+  HloComputation::Builder bcompute("ComputeFunction");
+  auto shape1 = ShapeUtil::MakeShape(F32, {});
+  auto shape2 = ShapeUtil::MakeShape(S32, {});
+  auto p2 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(0, shape1, "x0"));
+  auto p3 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(1, shape2, "x1"));
+  auto p4 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(2, shape1, "y0"));
+  auto p5 =
+      bcompute.AddInstruction(HloInstruction::CreateParameter(3, shape2, "y1"));
+  std::vector<HloInstruction*> compute_vec = {
+      bcompute.AddInstruction(
+          HloInstruction::CreateBinary(shape1, HloOpcode::kMinimum, p2, p4)),
+      bcompute.AddInstruction(
+          HloInstruction::CreateBinary(shape2, HloOpcode::kMinimum, p3, p5))};
+  bcompute.AddInstruction(HloInstruction::CreateTuple(compute_vec));
+  auto compute_tuple = m_->AddEmbeddedComputation(bcompute.Build());
+  std::vector<HloInstruction*> input_vec = {input1, input2};
+  auto init1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
+  auto init2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
+  std::vector<HloInstruction*> init_vec = {init1, init2};
+  auto padding = std::pair<int64, int64>(0, 0);
+  TF_ASSERT_OK_AND_ASSIGN(auto window,
+                          ShapeInference::InferWindowFromDimensions(
+                              {3}, {2}, absl::MakeSpan(&padding, 1),
+                              /*lhs_dilation=*/{},
+                              /*rhs_dilation=*/{}));
+  std::vector<const Shape*> input_shapes = {&input1->shape(), &input2->shape()};
+  std::vector<const Shape*> init_shapes = {&init1->shape(), &init2->shape()};
+  TF_ASSERT_OK_AND_ASSIGN(Shape shape,
+                          ShapeInference::InferReduceWindowShape(
+                              input_shapes, init_shapes, window,
+                              compute_tuple->ComputeProgramShape()));
+  builder.AddInstruction(HloInstruction::CreateReduceWindow(
+      shape, input_vec, init_vec, window, compute_tuple));
+  auto r1 = LiteralUtil::CreateR1<float>({100, 1});
+  auto r2 = LiteralUtil::CreateR1<int>({15, 12});
+  auto expected = LiteralUtil::MakeTuple({&r1, &r2});
+  m_->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
 TEST_P(HloEvaluatorBf16Test, StridedSlice) {
   HloComputation::Builder b(TestName());
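As a sanity check on the expected literals in these two tests: a window of size 3 with stride 2 over a five-element rank-1 input yields two windows, covering elements [0, 2] and [2, 4]. For the float input {10000, 1000, 100, 10, 1}:

    window 0: min(10000, 1000, 100) = 100
    window 1: min(100, 10, 1)       = 1
    expected: {100, 1}

The second test runs the same windows over the integer input {15, 28, 300, 107, 12}, giving min(15, 28, 300) = 15 and min(300, 107, 12) = 12, hence the expected {15, 12}.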
tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/meta/type_traits.h"
 #include "absl/types/optional.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -664,6 +665,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                 nullptr>
   Status HandleMinimum(HloInstruction* minimum) {
+    VLOG(2) << "Evaluating minimum\n";
     TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
                         ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
                                                         ElementwiseT rhs_el) {
@@ -1932,18 +1934,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
   }
 
   Status HandleReduceWindow(HloInstruction* reduce_window) override {
-    if (reduce_window->shape().IsTuple()) {
-      return Status(tensorflow::error::UNIMPLEMENTED,
-                    "Variadic reduce window op is not yet fully supported.");
-    }
-    auto operand = reduce_window->operand(0);
+    auto* reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
     const Window& window = reduce_window->window();
     HloComputation* function = reduce_window->to_apply();
     TF_ASSIGN_OR_RETURN(
         auto inferred_return_shape,
         ShapeInference::InferReduceWindowShape(
-            /*operand_shape=*/reduce_window->operand(0)->shape(),
-            /*init_value=*/reduce_window->operand(1)->shape(), window,
+            reduce_window_instr->input_array_shapes(),
+            reduce_window_instr->init_value_shapes(), window,
             /*to_apply_shape=*/function->ComputeProgramShape()));
     TF_RET_CHECK(
         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
@@ -1952,62 +1950,101 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
         << " but is inferred to be: "
         << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
 
-    const Literal& operand_literal =
-        parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
-    VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
-    const Literal& init_literal =
-        parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
-    VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
-    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
-    auto init_scalar = init_literal.Get<ReturnT>({});
+    absl::InlinedVector<const Literal*, 2> input_literal_vec, init_literal_vec;
+    auto input_arrays = reduce_window_instr->input_arrays();
+    auto init_values = reduce_window_instr->init_values();
+    int64 num_args = input_arrays.size();
+    for (int i = 0; i < num_args; ++i) {
+      const Literal& input_literal =
+          parent_->GetEvaluatedLiteralFor(input_arrays[i]);
+      VLOG(3) << "HandleReduceWindow arg_literal: " << input_literal.ToString();
+      input_literal_vec.push_back(&input_literal);
+      const Literal& init_literal =
+          parent_->GetEvaluatedLiteralFor(init_values[i]);
+      VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
+      TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+      init_literal_vec.push_back(&init_literal);
+    }
     // Creates a Shape object from window, for iteration below.
-    std::vector<int64> window_dimension_sizes;
+    absl::InlinedVector<int64, 2> window_dimension_sizes;
     for (const auto& window_dimension : window.dimensions()) {
       window_dimension_sizes.push_back(window_dimension.size());
     }
     const Shape window_shape = ShapeUtil::MakeShape(
-        operand->shape().element_type(), window_dimension_sizes);
-
-    DimensionVector window_index(window.dimensions_size());
-    DimensionVector operand_index(operand_literal.shape().rank());
+        input_arrays[0]->shape().element_type(), window_dimension_sizes);
 
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    Literal result(reduce_window->shape());
     // For each resulting dimension, calculate and assign computed value.
-    TF_RETURN_IF_ERROR(
-        result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
-          ReturnT result_val = init_scalar;
-          std::fill(window_index.begin(), window_index.end(), 0);
-          std::fill(operand_index.begin(), operand_index.end(), 0);
-          IterateThroughWindow(
-              window_shape, window, operand_literal.shape(), output_index,
-              [&](const std::vector<int64>& operand_index) {
-                auto curr_val = operand_literal.Get<ReturnT>(operand_index);
-                // Evaluate computation with specified literal operands.
-                const auto curr_val_literal =
-                    LiteralUtil::CreateR0<ReturnT>(curr_val);
-                const auto result_val_literal =
-                    LiteralUtil::CreateR0<ReturnT>(result_val);
-                Literal computed_result =
-                    embedded_evaluator
-                        .Evaluate(*function,
-                                  {&result_val_literal, &curr_val_literal})
-                        .ConsumeValueOrDie();
-                // Clear visit states so that we can use the evaluator again
-                // on the same computation.
-                embedded_evaluator.ResetVisitStates();
-                result_val = computed_result.Get<ReturnT>({});
-              });
-          return result_val;
-        }));
+    auto evaluate_impl =
+        [&](absl::Span<const int64> output_index) -> std::vector<Literal> {
+      std::vector<Literal> computed_result;
+      computed_result.reserve(init_literal_vec.size());
+      for (const auto* init : init_literal_vec) {
+        computed_result.push_back(init->Clone());
+      }
+      IterateThroughWindow(
+          window_shape, window, input_literal_vec[0]->shape(), output_index,
+          [&](absl::Span<const int64> operand_index) -> void {
+            absl::InlinedVector<const Literal*, 2> args;
+            for (auto& curr_result_val : computed_result) {
+              VLOG(2) << "Pushing:" << curr_result_val.ToString() << "\n";
+              args.push_back(&curr_result_val);
+            }
+            absl::InlinedVector<Literal, 2> curr_val_literal_vec(
+                input_literal_vec.size());
+            for (const auto* input_literal : input_literal_vec) {
+              // Evaluate computation with specified literal operands.
+              curr_val_literal_vec.push_back(Literal(ShapeUtil::MakeShape(
+                  input_literal->shape().element_type(), {})));
+              TF_CHECK_OK(curr_val_literal_vec.back().CopyElementFrom(
+                  *input_literal, operand_index, {}));
+              VLOG(2) << "Pushing:" << curr_val_literal_vec.back().ToString()
+                      << "\n";
+              args.push_back(&curr_val_literal_vec.back());
+            }
+            computed_result[0] = embedded_evaluator.Evaluate(*function, args)
+                                     .ConsumeValueOrDie();
+            VLOG(2) << "Computed result:" << computed_result[0].ToString()
+                    << "\n";
+            // Clear visit states so that we can use the evaluator again
+            // on the same computation.
+            embedded_evaluator.ResetVisitStates();
+            if (inferred_return_shape.IsTuple()) {
+              computed_result = computed_result[0].DecomposeTuple();
+            }
+          });
+      VLOG(2) << "Final result size:" << computed_result.size() << "\n";
+      for (const auto& res : computed_result) {
+        VLOG(2) << res.ToString() << "\n";
+      }
+      return computed_result;
+    };
+    Literal result(inferred_return_shape);
+    if (inferred_return_shape.IsTuple()) {
+      absl::InlinedVector<Literal, 1> results(num_args);
+      for (int64 i = 0; i < num_args; ++i) {
+        results[i] = Literal(inferred_return_shape.tuple_shapes(i));
+      }
+      TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+          inferred_return_shape.tuple_shapes(0),
+          [&](absl::Span<const int64> output_index) -> StatusOr<bool> {
+            std::vector<Literal> computed_result_vec =
+                evaluate_impl(output_index);
+            for (int i = 0; i < computed_result_vec.size(); ++i) {
+              TF_RETURN_IF_ERROR(results[i].CopyElementFrom(
+                  computed_result_vec[i], {}, output_index));
+            }
+            return true;
+          }));
+      result = Literal::MoveIntoTuple(absl::MakeSpan(results));
+      VLOG(2) << "Final result is:" << result.ToString() << "\n";
+    } else {
+      TF_RETURN_IF_ERROR(
+          result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+            return evaluate_impl(output_index)[0].template Get<ReturnT>({});
+          }));
+    }
+    VLOG(2) << "Final result is:" << result.ToString() << "\n";
     parent_->evaluated_[reduce_window] = std::move(result);
     return Status::OK();
   }
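In the tuple branch above, a per-leaf result literal is filled one output element at a time and the leaves are then fused into a single tuple literal. A minimal sketch of that Literal::MoveIntoTuple pattern, using hypothetical values (the real code sizes the leaves from inferred_return_shape):

    // Build two rank-1 leaves, then move them into a (f32[2], f32[2]) tuple.
    absl::InlinedVector<Literal, 2> parts;
    parts.push_back(LiteralUtil::CreateR1<float>({100.f, 1.f}));
    parts.push_back(LiteralUtil::CreateR1<float>({100.f, 1.f}));
    Literal tuple = Literal::MoveIntoTuple(absl::MakeSpan(parts));
    // `parts` is left in a moved-from state; only `tuple` should be used now.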
tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -881,12 +881,16 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) {
 }
 
 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
+  VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
+  auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
+  auto input_shapes = reduce_window_instr->input_array_shapes();
+  VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
+  auto init_shapes = reduce_window_instr->init_value_shapes();
+  VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
   TF_RETURN_IF_ERROR(CheckShape(
-      reduce_window,
-      ShapeInference::InferReduceWindowShape(
-          reduce_window->operand(0)->shape(),
-          reduce_window->operand(1)->shape(), reduce_window->window(),
-          reduce_window->to_apply()->ComputeProgramShape())));
+      reduce_window, ShapeInference::InferReduceWindowShape(
+                         input_shapes, init_shapes, reduce_window->window(),
+                         reduce_window->to_apply()->ComputeProgramShape())));
 
   return allow_mixed_precision_
              ? Status::OK()
tensorflow/compiler/xla/service/shape_inference.cc
@@ -2168,8 +2168,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
-    absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
-    const Window& window, const ProgramShape& to_apply_shape) {
+    absl::Span<const Shape* const> operands,
+    absl::Span<const Shape* const> init_values, const Window& window,
+    const ProgramShape& to_apply_shape) {
   auto number_of_input = operands.size();
   // Check that all of the reduced tensors have the same dimensions. The element
   // types may be different.
tensorflow/compiler/xla/service/shape_inference.h
@@ -168,8 +168,9 @@ class ShapeInference {
       const Shape& init_value,
       const Window& window);
   static StatusOr<Shape> InferReduceWindowShape(
-      absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
-      const Window& window, const ProgramShape& to_apply_shape);
+      absl::Span<const Shape* const> operands,
+      absl::Span<const Shape* const> init_values, const Window& window,
+      const ProgramShape& to_apply_shape);
 
   static StatusOr<Shape> InferReduceWindowShape(
       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
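
The signature change from absl::Span<const Shape*> to absl::Span<const Shape* const> makes the span's elements const as well as the pointees, which lets the function accept const containers and braced initializer lists. A minimal sketch of the difference, with illustrative function names that are not from the commit:

    void TakesConstSpan(absl::Span<const Shape* const> shapes);
    void TakesMutableSpan(absl::Span<const Shape*> shapes);

    Shape s = ShapeUtil::MakeShape(F32, {});
    std::vector<const Shape*> vec = {&s};
    const std::vector<const Shape*> const_vec = {&s};
    TakesConstSpan(vec);          // OK
    TakesConstSpan(const_vec);    // OK: elements are read-only via the span
    TakesMutableSpan(vec);        // OK
    TakesMutableSpan(const_vec);  // error: span elements would be writable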