[XLA] Move xla::legacy_flags into xla namespace.

PiperOrigin-RevId: 220680181
This commit is contained in:
Justin Lebar 2018-11-08 12:15:21 -08:00 committed by TensorFlower Gardener
parent 99c50e5777
commit 5d5b912c7c
50 changed files with 175 additions and 213 deletions

View File

@ -93,7 +93,7 @@ cc_library(
":tfcompile_lib",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
@ -132,7 +132,7 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::AppendDebugOptionsFlags(&flag_list);
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);

View File

@ -22,7 +22,7 @@ cc_library(
hdrs = ["mark_for_compilation_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
@ -34,7 +34,7 @@ cc_library(
hdrs = ["xla_device_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
@ -46,7 +46,7 @@ cc_library(
hdrs = ["build_xla_ops_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
@ -58,7 +58,7 @@ cc_library(
hdrs = ["xla_ops_common_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],

View File

@ -16,7 +16,7 @@ limitations under the License.
#include <mutex> // NOLINT
#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -34,7 +34,7 @@ void AllocateAndParseFlags() {
Flag("tf_xla_enable_lazy_compilation",
&flags->tf_xla_enable_lazy_compilation, ""),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
xla::ParseFlagsFromEnv(*flag_list);
}
} // namespace

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -64,7 +64,7 @@ static void AllocateFlags() {
Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
"enable fusion of element-wise operations only using XLA when "
"global_jit_level is ON*.")});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
xla::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -41,7 +41,7 @@ static void AllocateFlags() {
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
xla::ParseFlagsFromEnv(*flag_list);
}
// Return a pointer to the XlaDeviceFlags struct;

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -35,7 +35,7 @@ void AllocateAndParseFlags() {
Flag("tf_xla_always_defer_compilation",
&flags->tf_xla_always_defer_compilation, ""),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
xla::ParseFlagsFromEnv(*flag_list);
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {

View File

@ -438,7 +438,7 @@ cc_library(
"dump_graph.h",
],
deps = [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -41,7 +41,7 @@ static void AllocateFlags() {
"Path prefix to which graphs dumped during debugging should be "
"written."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
xla::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's

View File

@ -68,7 +68,7 @@ cc_library(
visibility = [":friends"],
deps = [
":xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla:debug_options_flags",
],
)
@ -735,6 +735,70 @@ tf_cc_test(
],
)
cc_library(
name = "parse_flags_from_env",
srcs = ["parse_flags_from_env.cc"],
hdrs = ["parse_flags_from_env.h"],
deps =
[
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "parse_flags_from_env_test",
srcs = ["parse_flags_from_env_test.cc"],
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "debug_options_flags",
srcs = [
"debug_options_flags.cc",
"debug_options_parsers.h",
],
hdrs = ["debug_options_flags.h"],
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "debug_options_parsers_test",
size = "small",
srcs = [
"debug_options_parsers.h",
"debug_options_parsers_test.cc",
],
deps =
[
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)
# -----------------------------------------------------------------------------
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.

View File

@ -68,6 +68,7 @@ cc_library(
deps = [
":global_data",
":xla_computation",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:service_interface",
@ -76,7 +77,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@ -236,13 +236,13 @@ tf_cc_test(
deps = [
":xla_builder",
":xla_computation",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/core:test",

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@ -465,8 +465,7 @@ StatusOr<string> Client::ExecutionStatsAsString(
const XlaComputation& computation, const ExecutionProfile& profile) {
TF_ASSIGN_OR_RETURN(
auto computation_stats,
GetComputationStats(computation,
legacy_flags::GetDebugOptionsFromFlags()));
GetComputationStats(computation, GetDebugOptionsFromFlags()));
int64 total_flops =
computation_stats.flop_count() + computation_stats.transcendental_count();
if (profile.compute_time_ns() > 0) {

View File

@ -60,8 +60,8 @@ class LocalExecutable {
// Validates that the given arguments and options satisfy various constraints
// of the computation.
//
// The given ExecutableRunOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
// The given ExecutableRunOptions override any values from the TF_XLA_FLAGS
// environment variable.
Status ValidateExecutionOptions(
const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);
@ -69,8 +69,8 @@ class LocalExecutable {
// Records the computation in a SessionModule proto with the arguments used to
// invoke it, and the result. Enabled by flag: --xla_dump_executions_to.
//
// The given ServiceExecutableRunOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
// The given ServiceExecutableRunOptions override any values from the
// TF_XLA_FLAGS environment variable.
StatusOr<ScopedShapedBuffer> ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const absl::Span<const ShapedBuffer* const> arguments);
@ -114,8 +114,8 @@ class LocalClient : public Client {
// Build and return a LocalExecutable object. The executable is compiled using
// the given XlaComputation, argument layouts and options.
//
// The given ExecutableBuildOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
// The given ExecutableBuildOptions override any values from the TF_XLA_FLAGS
// environment variable.
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -43,7 +43,7 @@ class XlaBuilderTest : public ::testing::Test {
const HloModuleProto& proto = computation.proto();
TF_ASSIGN_OR_RETURN(const auto& config,
HloModule::CreateModuleConfigFromProto(
proto, legacy_flags::GetDebugOptionsFromFlags()));
proto, GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(proto, config);
}
@ -54,7 +54,7 @@ class XlaBuilderTest : public ::testing::Test {
const HloModuleProto& proto = computation.proto();
TF_ASSIGN_OR_RETURN(const auto& config,
HloModule::CreateModuleConfigFromProto(
proto, legacy_flags::GetDebugOptionsFromFlags()));
proto, GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(proto, config);
}

View File

@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
namespace xla {
namespace legacy_flags {
namespace {
DebugOptions* flag_values;
@ -101,8 +99,8 @@ void AllocateFlags() {
[](string comma_separated_values) {
auto* extra_options_map =
flag_values->mutable_xla_backend_extra_options();
impl::parse_xla_backend_extra_options(extra_options_map,
comma_separated_values);
parse_xla_backend_extra_options(extra_options_map,
comma_separated_values);
return true;
};
@ -111,8 +109,8 @@ void AllocateFlags() {
[](string reduce_precision_option_value) {
HloReducePrecisionOptions* option_proto =
flag_values->add_hlo_reduce_precision_options();
return impl::parse_xla_reduce_precision_option(
option_proto, reduce_precision_option_value);
return parse_xla_reduce_precision_option(option_proto,
reduce_precision_option_value);
};
flag_objects = new std::vector<tensorflow::Flag>({
@ -353,5 +351,4 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
return *flag_values;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_
#include <vector>
@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Appends flag definitions for debug options to flag_list.
void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list);
@ -32,7 +31,6 @@ void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list);
// first.
xla::DebugOptions GetDebugOptionsFromFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
#include <vector>
#include "absl/strings/numbers.h"
@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla.pb.h"
namespace xla {
namespace legacy_flags {
namespace impl {
template <typename T>
void parse_xla_backend_extra_options(T* extra_options_map,
@ -140,8 +138,6 @@ inline bool parse_xla_reduce_precision_option(
return true;
}
} // namespace impl
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_

View File

@ -15,7 +15,7 @@ limitations under the License.
// Test for parse_flags_from_env.cc
#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include <unordered_map>
#include <vector>
@ -23,13 +23,12 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace legacy_flags {
// Test that the xla_backend_extra_options flag is parsed correctly.
TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) {
std::unordered_map<string, string> test_map;
string test_string = "aa=bb,cc,dd=,ee=ff=gg";
impl::parse_xla_backend_extra_options(&test_map, test_string);
parse_xla_backend_extra_options(&test_map, test_string);
EXPECT_EQ(test_map.size(), 4);
EXPECT_EQ(test_map.at("aa"), "bb");
EXPECT_EQ(test_map.at("cc"), "");
@ -41,7 +40,7 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) {
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) {
HloReducePrecisionOptions proto;
string test_string = "OP_OUTPUTS=5,10:add,dot";
EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
@ -56,7 +55,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) {
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) {
HloReducePrecisionOptions proto;
string test_string = "OP_OUTPUTS=5,10:add,dot;";
EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
@ -71,7 +70,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) {
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) {
HloReducePrecisionOptions proto;
string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz";
EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
@ -84,7 +83,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) {
TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) {
HloReducePrecisionOptions proto;
string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz";
EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string));
EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
EXPECT_EQ(proto.exponent_bits(), 5);
EXPECT_EQ(proto.mantissa_bits(), 10);
@ -96,7 +95,6 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) {
EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz");
}
} // namespace legacy_flags
} // namespace xla
int main(int argc, char* argv[]) {

View File

@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
namespace xla {
ExecutionOptions CreateDefaultExecutionOptions() {
ExecutionOptions execution_options;
*(execution_options.mutable_debug_options()) =
legacy_flags::GetDebugOptionsFromFlags();
*(execution_options.mutable_debug_options()) = GetDebugOptionsFromFlags();
return execution_options;
}

View File

@ -1,82 +0,0 @@
# Legacy command-line flags for the XLA libraries.
# Please do not add more flags to this package.
# The XLA libraries were written in an environment that allowed command-line
# flags to be scattered freely throughout the libraries. This model, while
# initially convenient, leads to a proliferation in unused command-line flags
# in tests and binaries, and serious problems in servers, where one might wish
# parameters to be different in independent RPC calls to the same routine.
#
# Please don't add more flags. If you're a library author, pass options and
# parameters explicitly through the library's interface.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "parse_flags_from_env",
srcs = ["parse_flags_from_env.cc"],
hdrs = ["parse_flags_from_env.h"],
deps =
[
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "parse_flags_from_env_test",
srcs = ["parse_flags_from_env_test.cc"],
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "debug_options_flags",
srcs = [
"debug_options_flags.cc",
"debug_options_parsers.h",
],
hdrs = ["debug_options_flags.h"],
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "debug_options_parsers_test",
size = "small",
srcs = [
"debug_options_parsers.h",
"debug_options_parsers_test.cc",
],
deps =
[
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

View File

@ -22,7 +22,7 @@ limitations under the License.
#include <string.h>
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried
static const char kWS[] = " \t\r\n"; // whitespace
@ -202,5 +201,4 @@ void ResetFlagsFromEnvForTesting(int** pargc, std::vector<char*>** pargv) {
*pargv = &env_argv->argv;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
#ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_
#define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_
// This module exports ParseFlagsFromEnv(), which allows other modules to parse
// flags from the environment variable TF_XLA_FLAGS, or (if the first
@ -50,7 +50,6 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet
// unrecognized flags passed in from the environment, and return its
@ -60,7 +59,6 @@ bool ParseFlagsFromEnv(const std::vector<tensorflow::Flag>& flag_list);
// Used only for testing. Not to be used by clients.
void ResetFlagsFromEnvForTesting(int** pargc, std::vector<char*>** pargv);
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
#endif // TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_

View File

@ -15,7 +15,7 @@ limitations under the License.
// Test for parse_flags_from_env.cc
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include <stdio.h>
#include <stdlib.h>
@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Test that XLA flags can be set from the environment.
// Failure messages are accompanied by the text in msg[].
@ -159,12 +158,11 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) {
}
}
} // namespace legacy_flags
} // namespace xla
int main(int argc, char* argv[]) {
// Save name of binary so that it may invoke itself.
xla::legacy_flags::binary_name = argv[0];
xla::binary_name = argv[0];
bool recursing = false;
xla::int32 int_flag = 1;
const std::vector<tensorflow::Flag> flag_list = {
@ -173,7 +171,7 @@ int main(int argc, char* argv[]) {
tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list);
bool parse_ok = xla::ParseFlagsFromEnv(flag_list);
if (!parse_ok) {
LOG(QFATAL) << "can't parse from environment\n" << usage;
}

View File

@ -603,11 +603,11 @@ cc_library(
hdrs = ["platform_util.h"],
deps = [
":compiler",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
@ -663,6 +663,7 @@ cc_library(
":source_map_util",
":stream_pool",
":transfer_manager",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:service_interface",
@ -674,7 +675,6 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
@ -731,12 +731,12 @@ cc_library(
":computation_layout",
":platform_util",
":service",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
@ -835,6 +835,7 @@ cc_library(
":maybe_owning_device_memory",
":shaped_buffer",
":stream_pool",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:status",
@ -842,7 +843,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
@ -2434,12 +2434,12 @@ tf_cc_test(
":hlo_graph_dumper",
":hlo_matchers",
":hlo_runner",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
@ -2840,8 +2840,8 @@ tf_cc_test(
":hlo_domain_isolator",
":hlo_domain_remover",
":hlo_parser",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",

View File

@ -2308,7 +2308,7 @@ ENTRY Main {
)";
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
ParseAndVerifyModule(hlo_text, config);
auto buffers = RunBufferAssignment(&module());

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/platform_util.h"

View File

@ -110,6 +110,6 @@ Compiler::GetPlatformCompilers() {
}
AotCompilationOptions::AotCompilationOptions()
: debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {}
: debug_options_(GetDebugOptionsFromFlags()) {}
} // namespace xla

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -1896,7 +1896,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) {
tensorflow::testing::StopTiming();
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_SequentialWhiles");
@ -1936,7 +1936,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
tensorflow::testing::StopTiming();
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_ParallelWhiles");
@ -2003,7 +2003,7 @@ std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
tensorflow::testing::StopTiming();
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
CopyInsertion copy_insertion;
const Shape element_shape = ShapeUtil::MakeShape(F32, {});
std::vector<HloInstruction*> tuple_params(num_tuple_inputs);

View File

@ -961,12 +961,12 @@ tf_cc_test(
srcs = ["cpu_copy_insertion_test.cc"],
deps = [
":cpu_copy_insertion",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"

View File

@ -37,7 +37,7 @@ cc_library(
hdrs = ["gpu_codegen_test.h"],
tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/platform/logging.h"
@ -25,7 +25,7 @@ namespace gpu {
std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) {
HloModuleConfig config;
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
auto debug_options = GetDebugOptionsFromFlags();
debug_options.set_xla_gpu_ftz(ftz);
debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
// TODO(b/38354253): Change tests to use Parameters instead of Constants.

View File

@ -85,7 +85,7 @@ TEST_F(GpuUnrollingTest, UnrollFourTimes) {
TEST_F(GpuUnrollingTest, UnrollDefaultTimes) {
// The default unrolling factor is 4.
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
@ -67,7 +67,7 @@ class HloDomainTest : public HloVerifiedTestBase {
StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
ParseAndVerifyModule(hlo_string, config);
return &module();
}

View File

@ -1337,7 +1337,7 @@ void BM_ReducePrecisely(int num_iters) {
tensorflow::testing::StopTiming();
HloComputation::Builder b("BM_ReducePrecisely");
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
HloModule module("BM_ReducePrecisely", config);
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -217,8 +217,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
// fix the number of devices to one. However we do let the user override
// this behavior to help run tests on the host that run models in parallel
// across multiple devices.
device_count = legacy_flags::GetDebugOptionsFromFlags()
.xla_force_host_platform_device_count();
device_count =
GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
VLOG(1) << "Initializing devices";

View File

@ -23,9 +23,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@ -292,7 +292,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
config->set_seed(execution_options->seed());
config->set_debug_options(execution_options->debug_options());
} else {
config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config->set_debug_options(GetDebugOptionsFromFlags());
}
if (execute_backend_ != nullptr &&

View File

@ -22,8 +22,8 @@ limitations under the License.
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"

View File

@ -44,7 +44,7 @@ cc_library(
testonly = True,
srcs = ["xla_internal_test_main.cc"],
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
@ -117,12 +117,12 @@ cc_library(
deps = [
":literal_test_util",
":test_utils",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@ -135,7 +135,7 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
}
DebugOptions HloTestBase::GetDebugOptionsForTest() {
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
auto debug_options = GetDebugOptionsFromFlags();
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
debug_options.add_xla_disable_hlo_passes("constant_folding");
debug_options.set_xla_gpu_max_kernel_unroll_factor(1);

View File

@ -126,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test {
static std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
config.set_debug_options(GetDebugOptionsFromFlags());
return absl::make_unique<HloModule>(TestName(), config);
}
};

View File

@ -380,7 +380,7 @@ static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
GTEST_API_ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::AppendDebugOptionsFlags(&flag_list);
std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv);
auto usage = tensorflow::Flags::Usage(argv[0], flag_list);

View File

@ -15,14 +15,14 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
GTEST_API_ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::AppendDebugOptionsFlags(&flag_list);
auto usage = tensorflow::Flags::Usage(argv[0], flag_list);
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
LOG(ERROR) << "\n" << usage;

View File

@ -33,6 +33,7 @@ cc_library(
name = "dumped_computation_to_graphviz_library",
srcs = ["dumped_computation_to_graphviz.cc"],
deps = [
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -40,7 +41,6 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
@ -78,6 +78,7 @@ cc_library(
name = "replay_computation_library",
srcs = ["replay_computation.cc"],
deps = [
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -91,7 +92,6 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
@ -207,13 +207,13 @@ tf_cc_binary(
name = "dumped_computation_to_tf_graphdef",
srcs = ["dumped_computation_to_tf_graphdef.cc"],
deps = [
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_proto",

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -54,7 +54,7 @@ void RealMain(absl::Span<char* const> args) {
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
XlaComputation computation =
client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
DebugOptions debug_options = GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
ComputationStats stats =
client->GetComputationStats(computation, debug_options)
@ -68,7 +68,7 @@ void RealMain(absl::Span<char* const> args) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -53,7 +53,7 @@ void RealMain(absl::Span<char* const> args) {
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
XlaComputation computation =
client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
DebugOptions debug_options = GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
debug_options.set_xla_hlo_dump_as_graphdef(true);
ComputationStats stats =
@ -68,7 +68,7 @@ void RealMain(absl::Span<char* const> args) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {

View File

@ -47,8 +47,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@ -191,8 +191,7 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// Run the computation num_runs times, and return the result from the last
// execution.
const bool xla_hlo_profile =
legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile();
StreamExecutorMemoryAllocator allocator(
client->platform(),
{client->platform()->ExecutorForDevice(0).ValueOrDie()});

View File

@ -12,6 +12,7 @@ cc_library(
hdrs = ["xrt_state_ops.h"],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -21,7 +22,6 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo_proto",