[XLA] Clean up clang tidy readability warnings in compiler/xla
* lambda capture 'builder' is not used
* using decl 'Printf' is unused
* lambda capture 'this' is not used (17 times)
* lambda capture 'buffer_liveness' is not used
* lambda capture 'computation' is not used
* lambda capture 'operand_to_generator' is not used
* lambda capture 'M' is not used
* using decl 'InvalidParameterArgument' is unused
* lambda capture 'sum' is not used
* lambda capture 's' is not used
* lambda capture 'epsilon' is not used

PiperOrigin-RevId: 207542895
commit 2b8df9f406 (parent e70f94ee08)
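Background on the pattern these hunks fix: clang-tidy's readability checks flag lambda captures that the body never uses; a stale capture is dead weight, and an unused `this` capture needlessly ties the lambda to the enclosing object. A minimal standalone sketch of the before/after (the function and variable names here are illustrative, not from the XLA sources):

    #include <vector>

    // Hypothetical example, not from the XLA sources.
    int SumPlusBias(const std::vector<int>& values, int bias) {
      // Before the cleanup, the lambda captured more than it used:
      //   auto add_bias = [&values, bias](int v) { return v + bias; };
      // clang-tidy: lambda capture 'values' is not used.
      auto add_bias = [bias](int v) { return v + bias; };  // capture only what's used
      int sum = 0;
      for (int v : values) sum += add_bias(v);
      return sum;
    }

The same reasoning applies to the removed `using` declarations below: a `using` decl whose name is never referenced in the file is deleted outright.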
@@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
 
   // Performs a single round of the Threefry2x32 algorithm, with a rotation
   // amount 'rotation'.
-  auto round = [builder](ThreeFry2x32State v, int rotation) {
+  auto round = [](ThreeFry2x32State v, int rotation) {
     v[0] = v[0] + v[1];
     v[1] = RotateLeftS32(v[1], rotation);
     v[1] = v[0] ^ v[1];
@@ -36,7 +36,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
-using tensorflow::strings::Printf;
 using tensorflow::strings::StrCat;
 
 namespace xla {
@@ -2006,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
   // Builds a convolution from <options> and runs algebraic simplification on
   // the computation. Returns a string description of the result of
   // simplification.
-  auto build_and_simplify = [&options, this]() -> string {
+  auto build_and_simplify = [&options]() -> string {
     HloComputation::Builder b(TestName());
 
     Window window;
@@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
                       ResolveInternal(data));
   for (const auto& shaped_buffer : replicated_buffers) {
     std::vector<ShapeIndex> shape_indices;
-    ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
-                               [this, &shape_indices](const Shape& /*subshape*/,
-                                                      const ShapeIndex& index) {
+    ShapeUtil::ForEachSubshape(
+        shaped_buffer->on_device_shape(),
+        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
           shape_indices.push_back(index);
         });
     for (const ShapeIndex& index : shape_indices) {
       TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                            shaped_buffer->device_ordinal()));
@@ -877,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation(
   // important reuse case where an elementwise instruction reuses one of its
   // operand's buffer. This improves locality.
   std::sort(sorted_buffers.begin(), sorted_buffers.end(),
-            [this, has_sequential_order, &liveness, &post_order_position,
-             assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
+            [has_sequential_order, &liveness, &post_order_position, assignment](
+                const LogicalBuffer* a, const LogicalBuffer* b) {
              // Primary sort is by decreasing buffer size.
              const int64 a_size = assignment->buffer_size_(*a);
              const int64 b_size = assignment->buffer_size_(*b);
@@ -1441,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets(
         const HloInstruction* while_hlo = instruction;
         ShapeUtil::ForEachSubshape(
             while_hlo->shape(),
-            [this, while_hlo, &points_to_analysis, &buffer_liveness,
-             buffer_size, computation, colocated_buffer_sets](
-                const Shape& /*subshape*/, const ShapeIndex& index) {
+            [this, while_hlo, &points_to_analysis, buffer_size,
+             colocated_buffer_sets](const Shape& /*subshape*/,
+                                    const ShapeIndex& index) {
               std::vector<const LogicalBuffer*> colocated_set;
               // Add while.init.
               AddBufferToColocatedSet(while_hlo->operand(0), index,
@@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
   // The body adds the reduced value of the Infeed data (first tuple element)
   // to the previous accumulator, and returns the accumulator and the continue
   // flag (second tuple element) as a tuple.
-  const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
+  const auto build_body = [&result_shape](const Shape& infeed_shape) {
     XlaComputation body;
     XlaBuilder builder("body");
     auto prev = Parameter(&builder, 0, result_shape, "prev");
@@ -2134,7 +2134,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
       };
     default:
-      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+      return [hlo](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()).c_str());
      };
@@ -293,7 +293,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
   // the respective location in ShapedBuffer.
   std::set<se::DeviceMemoryBase> buffers_in_result;
   TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus(
-      [&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
+      [&buffer_allocations, &buffers_in_result, this](
           const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
         const auto& sources = this->GetRootPointsToSet().element(index);
         // The points-to set is unambiguous so the set should be a
@@ -328,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
   if (linker.linkInModule(
           std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded,
           [](Module& M, const StringSet<>& GVS) {
-            internalizeModule(M, [&M, &GVS](const GlobalValue& GV) {
+            internalizeModule(M, [&GVS](const GlobalValue& GV) {
              return !GV.hasName() || (GVS.count(GV.getName()) == 0);
            });
          })) {
@@ -533,12 +533,12 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
   // TODO(b/33004697): Compute correct cost here, taking the actual number of
   // replicas into account.
   double flops = 0.0;
-  ShapeUtil::ForEachSubshape(
-      crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
+  ShapeUtil::ForEachSubshape(crs->shape(),
+                             [&](const Shape& subshape, const ShapeIndex&) {
                                if (ShapeUtil::IsArray(subshape)) {
                                  flops += ShapeUtil::ElementsIn(subshape);
                                }
                              });
   current_properties_[kFlopsKey] = flops;
   return Status::OK();
 }
@@ -2365,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
 TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
 
-  auto make_cond = [this, &data_shape]() {
+  auto make_cond = [&data_shape]() {
     auto builder = HloComputation::Builder(TestName() + ".Cond");
     auto data = builder.AddInstruction(
         HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -2374,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
     return builder.Build();
   };
 
-  auto make_body = [this, &data_shape]() {
+  auto make_body = [&data_shape]() {
     auto builder = HloComputation::Builder(TestName() + ".Body");
     auto data = builder.AddInstruction(
         HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
   TF_ASSERT_OK_AND_ASSIGN(
       SequentialHloOrdering::HloModuleSequence sequence,
       ScheduleComputationsInModule(*module,
-                                   [&TUPLE_SIZE](const BufferValue& buffer) {
+                                   [](const BufferValue& buffer) {
                                      return ShapeUtil::ByteSizeOf(
                                          buffer.shape(), TUPLE_SIZE);
                                    },
@@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out,
 string InstructionValueSet::ToString() const {
   string out =
       StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
-  ForEachElement([this, &out](const ShapeIndex& index,
-                              const HloValueSet& value_set) {
+  ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
    StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
  });
   return out;
@@ -1228,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
   const PointsToSet& points_to_set =
       constraints->points_to_analysis().GetPointsToSet(instruction);
   return points_to_set.ForEachElementWithStatus(
-      [this, &shape_layout, constraints](
+      [&shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
@@ -56,7 +56,6 @@ limitations under the License.
 
 using ::tensorflow::strings::Printf;
 using ::tensorflow::strings::StrCat;
-using ::xla::source_map_util::InvalidParameterArgument;
 
 namespace xla {
 
@@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
   // Copy the points-to set (and tuple sources) at index {element_index} of the
   // operand to the points-to set for this GetTupleElement instruction.
   points_to_set.ForEachMutableElement(
-      [&, this](const ShapeIndex& target_index,
-                PointsToSet::BufferList* points_to) {
+      [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
        // Construct an index into the operand by prepending element_index to
        // the index for the GetTupleElement instruction's points-to set.
        ShapeIndex src_index;
@@ -308,7 +307,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
   // Recursively copy the points to set of the operand tuple {0} to the output
   // element {0}.
   points_to_set.ForEachMutableElement(
-      [this, &points_to_set, &operand_points_to_set](
+      [&points_to_set, &operand_points_to_set](
          const ShapeIndex& index, PointsToSet::BufferList* buffers) {
        if (index.empty() || index[0] != 0) {
          return;
@@ -517,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
     const HloInstruction* instruction,
     TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
   GetPointsToSet(instruction)
-      .ForEachElement([this, buffers, instruction](
+      .ForEachElement([buffers, instruction](
                           const ShapeIndex& index,
                           const PointsToSet::BufferList& source_buffers) {
         // Add buffers which 'instruction' is the source of.
@@ -547,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
   PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
   const PointsToSet& src_points_to_set = GetPointsToSet(src);
   dst_points_to_set.ForEachMutableElement(
-      [this, &dst_points_to_set, &src_points_to_set](
+      [&dst_points_to_set, &src_points_to_set](
          const ShapeIndex& index, PointsToSet::BufferList* buffers) {
        *buffers = src_points_to_set.element(index);
        for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
@@ -1118,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
 TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
 
-  auto make_cond = [this, &data_shape]() {
+  auto make_cond = [&data_shape]() {
     auto builder = HloComputation::Builder(TestName() + ".Cond");
     auto data = builder.AddInstruction(
         HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -1127,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
     return builder.Build();
   };
 
-  auto make_body = [this, &data_shape]() {
+  auto make_body = [&data_shape]() {
     auto builder = HloComputation::Builder(TestName() + ".Body");
     auto data = builder.AddInstruction(
         HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) {
 
   // Write zero to all data elements.
   shape_tree.ForEachMutableElement(
-      [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; });
+      [](const ShapeIndex& /*index*/, int* data) { *data = 0; });
   EXPECT_EQ(0, shape_tree.element({}));
   EXPECT_EQ(0, shape_tree.element({0}));
   EXPECT_EQ(0, shape_tree.element({1}));
@@ -596,8 +596,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
   };
 
   auto comma_list_to_int64s =
-      [&s,
-       string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
+      [string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
     std::vector<int64> results;
     for (const string& piece : tensorflow::str_util::Split(input, ',')) {
      TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
@@ -733,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
         var4D, [epsilon](float a) { return a + epsilon; });
 
     auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
-        var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
+        var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });
 
     auto grad_output_times_var =
         *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
@@ -182,7 +182,7 @@ XLA_TEST_F(PrngTest, Uniformity256) {
 
 XLA_TEST_F(PrngTest, MapUsingRng) {
   // Build a x -> (x + U[0,1)) computation.
-  auto build_sum_rng = [this](XlaBuilder& builder) {
+  auto build_sum_rng = [](XlaBuilder& builder) {
     auto b = builder.CreateSubBuilder("sum_with_rng");
     auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input");
     Add(x,