Fix uses of private, mangled names for proto enumerators.
PiperOrigin-RevId: 256558252
commit 011dac5f0a
parent 1375ad20ad
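Context for the diff below (a sketch, not part of the commit): for an enum nested inside a message, protoc's C++ backend generates namespace-scope enumerators with mangled Message_Enum_VALUE names and also re-exports short aliases as members of the enclosing message class; the short spellings are the supported ones, which is what this change switches to. A self-contained sketch of the generated pattern, using hand-written stand-in types rather than the real generated headers:

// Sketch only: hand-written stand-ins approximating what protoc generates for
//   message OptimizerOptions { enum GlobalJitLevel { OFF = -1; ...; ON_2 = 2; } ... }
// None of this is the real TensorFlow generated code.
#include <iostream>

// Namespace-scope enum with mangled enumerator names (treated as private).
enum OptimizerOptions_GlobalJitLevel {
  OptimizerOptions_GlobalJitLevel_OFF = -1,
  OptimizerOptions_GlobalJitLevel_ON_2 = 2,
};

// The enclosing message class re-exports the enum and short enumerator aliases.
class OptimizerOptions {
 public:
  typedef OptimizerOptions_GlobalJitLevel GlobalJitLevel;
  static constexpr GlobalJitLevel OFF = OptimizerOptions_GlobalJitLevel_OFF;
  static constexpr GlobalJitLevel ON_2 = OptimizerOptions_GlobalJitLevel_ON_2;
};

int main() {
  // Spelling this commit moves to: the short, class-scoped alias.
  std::cout << OptimizerOptions::ON_2 << "\n";  // prints 2
  // Spelling it removes: the mangled namespace-scope enumerator.
  std::cout << OptimizerOptions_GlobalJitLevel_ON_2 << "\n";  // also 2
}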
@@ -128,9 +128,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl(
   TF_RETURN_IF_ERROR(AssertGraphDefIsUnclustered(graphdef));
 
   OptimizationPassRunner runner;
-  TF_RETURN_IF_ERROR(
-      runner.SetJitLevel(tensorflow::OptimizerOptions::GlobalJitLevel::
-                             OptimizerOptions_GlobalJitLevel_ON_2));
+  TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2));
   TF_RETURN_IF_ERROR(runner.AddCpus(32));
   TF_RETURN_IF_ERROR(runner.AddGpus(8));
 
@@ -211,9 +209,7 @@ Status BenchmarkMarkForCompilation(absl::string_view graph_def_path,
       ReadTextProto(Env::Default(), string(graph_def_path), &graph_def));
 
   OptimizationPassRunner runner;
-  TF_RETURN_IF_ERROR(
-      runner.SetJitLevel(tensorflow::OptimizerOptions::GlobalJitLevel::
-                             OptimizerOptions_GlobalJitLevel_ON_2));
+  TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2));
   TF_RETURN_IF_ERROR(runner.AddCpus(32));
   TF_RETURN_IF_ERROR(runner.AddGpus(8));
 
@@ -25,8 +25,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
   auto core_from_sharding =
       [](absl::optional<xla::OpSharding> sharding) -> int64 {
     if (sharding.has_value() &&
-        sharding.value().type() ==
-            xla::OpSharding::Type::OpSharding_Type_MAXIMAL) {
+        sharding.value().type() == xla::OpSharding::MAXIMAL) {
       return sharding.value().tile_assignment_devices(0);
     } else {
       return -1;
@@ -504,8 +504,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
           *possible_match,
           /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
       if (sharding.has_value()) {
-        TF_RET_CHECK(sharding.value().type() ==
-                     xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
+        TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL);
         const int core_annotation = sharding.value().tile_assignment_devices(0);
         if (core == -1 || core > core_annotation) {
           core = core_annotation;
@@ -85,8 +85,7 @@ ComputeArgAndRetvalCores(const Graph& graph) {
        auto sharding,
        ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
    if (sharding.has_value()) {
-      TF_RET_CHECK(sharding.value().type() ==
-                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
+      TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL);
      return sharding.value().tile_assignment_devices(0);
    } else {
      return -1;
@@ -832,7 +831,7 @@ Status XlaCompiler::BuildArguments(
     xla::XlaOp tuple;
     if (is_entry_computation) {
       xla::OpSharding tuple_sharding;
-      tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
+      tuple_sharding.set_type(xla::OpSharding::TUPLE);
       for (int64 parameter : *input_to_args) {
         auto it = arg_cores.find(parameter);
         const int core = it == arg_cores.end() ? 0 : it->second;
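The xla::OpSharding hunks follow the same pattern: Type is an enum nested in the OpSharding message, so the generated class exposes the short aliases xla::OpSharding::MAXIMAL and xla::OpSharding::TUPLE. A minimal usage sketch, assuming the generated header for xla_data.proto lives at the path below:

#include "tensorflow/compiler/xla/xla_data.pb.h"  // assumed location of the generated xla::OpSharding

// Builds a tuple sharding using the short enumerator alias, as in this commit.
xla::OpSharding MakeTupleSharding() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::TUPLE);
  // Equivalent but discouraged mangled spelling:
  //   sharding.set_type(xla::OpSharding_Type_TUPLE);
  return sharding;
}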