commit f69c7569cd
Merge commit for internal changes
@@ -39,7 +39,7 @@ config_setting(
 config_setting(
     name = "android_armeabi",
     values = {
-        "cc_target_os": "android",
+        "crosstool_top": "//external:android/crosstool",
         "cpu": "armeabi",
     },
     visibility = ["//visibility:public"],
@@ -218,7 +218,9 @@ filegroup(
         "//tensorflow/compiler/jit/ops:all_files",
         "//tensorflow/compiler/tests:all_files",
         "//tensorflow/compiler/tf2xla:all_files",
+        "//tensorflow/compiler/tf2xla/cc:all_files",
         "//tensorflow/compiler/tf2xla/kernels:all_files",
+        "//tensorflow/compiler/tf2xla/ops:all_files",
         "//tensorflow/compiler/xla:all_files",
         "//tensorflow/compiler/xla/client:all_files",
         "//tensorflow/compiler/xla/client/lib:all_files",
@@ -253,7 +255,7 @@ filegroup(
         "//tensorflow/contrib/data/python/kernel_tests:all_files",
         "//tensorflow/contrib/data/python/ops:all_files",
         "//tensorflow/contrib/data/python/util:all_files",
-        "//tensorflow/contrib/decision_trees:all_files",
+        "//tensorflow/contrib/decision_trees/proto:all_files",
         "//tensorflow/contrib/distributions:all_files",
         "//tensorflow/contrib/factorization:all_files",
         "//tensorflow/contrib/factorization/kernels:all_files",
@@ -284,6 +286,8 @@ filegroup(
         "//tensorflow/contrib/ndlstm:all_files",
         "//tensorflow/contrib/nn:all_files",
         "//tensorflow/contrib/opt:all_files",
+        "//tensorflow/contrib/predictor:all_files",
+        "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/rnn:all_files",
         "//tensorflow/contrib/saved_model:all_files",
         "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
@@ -302,10 +306,13 @@ filegroup(
         "//tensorflow/contrib/stateless:all_files",
         "//tensorflow/contrib/tensor_forest:all_files",
         "//tensorflow/contrib/tensor_forest/hybrid:all_files",
         "//tensorflow/contrib/tensor_forest/kernels/v4:all_files",
         "//tensorflow/contrib/tensor_forest/proto:all_files",
         "//tensorflow/contrib/tensorboard:all_files",
         "//tensorflow/contrib/testing:all_files",
+        "//tensorflow/contrib/text:all_files",
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
+        "//tensorflow/contrib/tpu:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
+        "//tensorflow/contrib/verbs:all_files",
@@ -353,70 +360,6 @@ filegroup(
         "//tensorflow/python/ops/distributions:all_files",
         "//tensorflow/python/saved_model:all_files",
         "//tensorflow/python/tools:all_files",
-        "//tensorflow/tensorboard:all_files",
-        "//tensorflow/tensorboard/backend:all_files",
-        "//tensorflow/tensorboard/backend/event_processing:all_files",
-        "//tensorflow/tensorboard/components:all_files",
-        "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files",
-        "//tensorflow/tensorboard/components/tf_backend:all_files",
-        "//tensorflow/tensorboard/components/tf_backend/test:all_files",
-        "//tensorflow/tensorboard/components/tf_color_scale:all_files",
-        "//tensorflow/tensorboard/components/tf_color_scale/test:all_files",
-        "//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
-        "//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files",
-        "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_globals:all_files",
-        "//tensorflow/tensorboard/components/tf_graph:all_files",
-        "//tensorflow/tensorboard/components/tf_graph/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_app:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_board:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_common:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_controls:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_info:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_loader:all_files",
-        "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_imports:all_files",
-        "//tensorflow/tensorboard/components/tf_option_selector:all_files",
-        "//tensorflow/tensorboard/components/tf_profile_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_runs_selector:all_files",
-        "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
-        "//tensorflow/tensorboard/components/tf_storage:all_files",
-        "//tensorflow/tensorboard/components/tf_storage/test:all_files",
-        "//tensorflow/tensorboard/components/tf_tensorboard:all_files",
-        "//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
-        "//tensorflow/tensorboard/components/tf_trace_viewer:all_files",
-        "//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
-        "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
-        "//tensorflow/tensorboard/components/vz_line_chart:all_files",
-        "//tensorflow/tensorboard/components/vz_projector:all_files",
-        "//tensorflow/tensorboard/components/vz_projector/test:all_files",
-        "//tensorflow/tensorboard/components/vz_sorting:all_files",
-        "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
-        "//tensorflow/tensorboard/demo:all_files",
-        "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
-        "//tensorflow/tensorboard/plugins:all_files",
-        "//tensorflow/tensorboard/plugins/audio:all_files",
-        "//tensorflow/tensorboard/plugins/distributions:all_files",
-        "//tensorflow/tensorboard/plugins/graphs:all_files",
-        "//tensorflow/tensorboard/plugins/histograms:all_files",
-        "//tensorflow/tensorboard/plugins/images:all_files",
-        "//tensorflow/tensorboard/plugins/projector:all_files",
-        "//tensorflow/tensorboard/plugins/scalars:all_files",
-        "//tensorflow/tensorboard/plugins/text:all_files",
-        "//tensorflow/tensorboard/scripts:all_files",
         "//tensorflow/tools/api/golden:all_files",
         "//tensorflow/tools/api/lib:all_files",
         "//tensorflow/tools/api/tests:all_files",
@@ -628,7 +628,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
                   // Target nodes
                   const char** c_target_oper_names, int ntargets,
                   const char** handle, TF_Status* status) {
-  status->status = Status::OK();
+  *handle = nullptr;

   std::vector<tensorflow::string> input_names(ninputs);
   std::vector<tensorflow::string> output_names(noutputs);
@@ -643,16 +643,12 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
     target_oper_names[i] = c_target_oper_names[i];
   }
   tensorflow::string new_handle;
-  Status result;
-  result = s->session->PRunSetup(input_names, output_names, target_oper_names,
-                                 &new_handle);
-  if (result.ok()) {
+  status->status = s->session->PRunSetup(input_names, output_names,
+                                         target_oper_names, &new_handle);
+  if (status->status.ok()) {
     char* buf = new char[new_handle.size() + 1];
     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
     *handle = buf;
-  } else {
-    *handle = nullptr;
-    status->status = result;
   }
 }
@@ -2326,6 +2322,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                          int ninputs, const TF_Output* outputs, int noutputs,
                          const TF_Operation* const* target_opers, int ntargets,
                          const char** handle, TF_Status* status) {
+  *handle = nullptr;
+
   if (!ExtendSessionGraphHelper(session, status)) {
     return;
   }
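The hunks above tighten the partial-run handle contract: `*handle` now starts as nullptr and is set only on success. A minimal sketch of how a client drives this flow through the C API, assuming an already-built `session`, operations `feed_op`/`fetch_op`, and a `TF_Tensor* input_tensor` (handle cleanup and error handling elided; this is not code from the commit):

    const char* handle = nullptr;
    TF_Status* status = TF_NewStatus();
    TF_Output feed = {feed_op, 0};
    TF_Output fetch = {fetch_op, 0};
    // Set up a partial run that will feed `feed` and fetch `fetch` later.
    TF_SessionPRunSetup(session, &feed, 1, &fetch, 1,
                        /*target_opers=*/nullptr, 0, &handle, status);
    if (TF_GetCode(status) == TF_OK) {
      // Per the change above, `handle` is non-null only when setup succeeded.
      TF_Tensor* fetch_value = nullptr;
      TF_SessionPRun(session, handle, &feed, &input_tensor, 1, &fetch,
                     &fetch_value, 1, /*target_opers=*/nullptr, 0, status);
    }
    TF_DeleteStatus(status);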
@@ -1101,8 +1101,7 @@ TF_CAPI_EXPORT extern void TF_SessionRun(
 // needed.
 //
 // On failure, out_status contains a tensorflow::Status with an error
-// message.
-// NOTE: This is EXPERIMENTAL and subject to change.
+// message. *handle is set to nullptr.
 TF_CAPI_EXPORT extern void TF_SessionPRunSetup(
     TF_Session*,
     // Input names
@@ -1118,7 +1117,6 @@ TF_CAPI_EXPORT extern void TF_SessionPRunSetup(

 // Continue to run the graph with additional feeds and fetches. The
 // execution state is uniquely identified by the handle.
-// NOTE: This is EXPERIMENTAL and subject to change.
 TF_CAPI_EXPORT extern void TF_SessionPRun(
     TF_Session*, const char* handle,
     // Input tensors
@@ -61,7 +61,6 @@ cc_library(
         ":gradients",
         ":ops",
         ":scope",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -274,10 +273,6 @@ cc_library(
     deps = [
         ":cc_ops",
         ":grad_op_registry",
-        ":ops",
-        ":scope",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
     ],
 )

@@ -305,10 +300,6 @@ cc_library(
         ":cc_ops",
         ":cc_ops_internal",
         ":grad_op_registry",
-        ":ops",
-        ":scope",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
     ],
 )

@@ -527,7 +518,6 @@ cc_library(
     deps = [
         ":coordinator",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
@@ -560,8 +550,6 @@ cc_library(
     srcs = ["training/coordinator.cc"],
     hdrs = ["training/coordinator.h"],
     deps = [
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
@@ -21,6 +21,9 @@ namespace tensorflow {
 /// SavedModel assets directory.
 constexpr char kSavedModelAssetsDirectory[] = "assets";

+/// SavedModel assets.extra directory.
+constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
+
 /// SavedModel assets key for graph collection-def.
 constexpr char kSavedModelAssetsKey[] = "saved_model_assets";

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/constants.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/protobuf_internal.h"
 #include "tensorflow/core/protobuf/saved_model.pb.h"
@@ -76,8 +77,16 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
       return Status::OK();
     }
   }
+  string tags_as_string = "{ ";
+  for (const string& tag : tags) {
+    tags_as_string = strings::StrCat(tags_as_string, tag, " ");
+  }
+  tags_as_string = strings::StrCat(tags_as_string, "}");
   return Status(error::Code::NOT_FOUND,
-                "Could not find meta graph def matching supplied tags.");
+                "Could not find meta graph def matching supplied tags: " +
+                    tags_as_string +
+                    ". To inspect available tag-sets in the SavedModel, please "
+                    "use the SavedModel CLI: `saved_model_cli`");
 }

 Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
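The new error path assembles the tag string with an explicit loop. A more compact equivalent, assuming the str_util helpers in tensorflow/core/lib/strings are usable here (a sketch, not the committed code):

    // Builds the same "{ tag1 tag2 }" string in one expression.
    string tags_as_string =
        strings::StrCat("{ ", str_util::Join(tags, " "), " }");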
@@ -133,9 +133,9 @@ TEST_F(LoaderTest, NoTagMatch) {
   Status st = LoadSavedModel(session_options, run_options, export_dir,
                              {"missing-tag"}, &bundle);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(
-      StringPiece(st.error_message())
-          .contains("Could not find meta graph def matching supplied tags."))
+  EXPECT_TRUE(StringPiece(st.error_message())
+                  .contains("Could not find meta graph def matching supplied "
+                            "tags: { missing-tag }"))
       << st.error_message();
 }

@@ -151,7 +151,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
   EXPECT_FALSE(st.ok());
   EXPECT_TRUE(
       StringPiece(st.error_message())
-          .contains("Could not find meta graph def matching supplied tags."))
+          .contains("Could not find meta graph def matching supplied tags: "))
       << st.error_message();
 }

@@ -18,10 +18,13 @@ limitations under the License.

 namespace tensorflow {

+/// Tag for the `gpu` graph.
+constexpr char kSavedModelTagGpu[] = "gpu";
+
 /// Tag for the `serving` graph.
 constexpr char kSavedModelTagServe[] = "serve";

-/// Tag for the `training` graph.`
+/// Tag for the `training` graph.
 constexpr char kSavedModelTagTrain[] = "train";

 }  // namespace tensorflow
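The new `kSavedModelTagGpu` constant complements the existing serve/train tags. A hypothetical load of a GPU-tagged SavedModel would pass it as the tag set (sketch only; `session_options`, `run_options`, and `export_dir` are assumed):

    SavedModelBundle bundle;
    // Select the MetaGraphDef that was exported with the "gpu" tag.
    TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                               {kSavedModelTagGpu}, &bundle));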
@@ -126,14 +126,11 @@ cc_library(
     deps = [
         ":tfcompile_lib",
         ":tfcompile_proto",
-        "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags",
-        "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
-        "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
         "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
         "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
         "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags",
         "//tensorflow/compiler/xla/legacy_flags:service_flags",
         "//tensorflow/compiler/xla/legacy_flags:util_flags",
         "//tensorflow/compiler/xla/service:compiler",
@@ -23,14 +23,11 @@ limitations under the License.
 #include "tensorflow/compiler/aot/flags.h"
 #include "tensorflow/compiler/aot/tfcompile.pb.h"
 #include "tensorflow/compiler/aot/tfcompile_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
@@ -136,14 +133,11 @@ int main(int argc, char** argv) {

   std::vector<tensorflow::Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);
-  xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list);
-  xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list);
-  xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
   xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
   xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendLlvmUtilFlags(&flag_list);
   xla::legacy_flags::AppendServiceFlags(&flag_list);
   xla::legacy_flags::AppendUtilFlags(&flag_list);

@@ -22,20 +22,6 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

-# This target can be used by XLA device plugins to prevent circular
-# dependencies, and provides access to all of the required headers
-# for building a device library.
-cc_header_only_library(
-    name = "xla_jit_headers_lib",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":xla_cpu_device",
-        ":xla_cpu_jit",
-        ":xla_gpu_device",
-        ":xla_gpu_jit",
-    ],
-)
-
 # Target that bundles up the XLA CPU and GPU JIT devices.
 cc_library(
     name = "jit",
@@ -283,3 +269,15 @@ filegroup(
     ),
     visibility = ["//tensorflow:__subpackages__"],
 )
+
+# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
+cc_header_only_library(
+    name = "xla_jit_headers_lib",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":xla_cpu_device",
+        ":xla_cpu_jit",
+        ":xla_gpu_device",
+        ":xla_gpu_jit",
+    ],
+)
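Moving `xla_jit_headers_lib` below the main targets does not change its purpose: it lets out-of-tree XLA device plugins compile against the JIT headers without a circular dependency. A hypothetical plugin BUILD file (names are illustrative) would use it like this:

    cc_library(
        name = "my_xla_plugin_device",
        srcs = ["my_xla_plugin_device.cc"],
        deps = ["//tensorflow/compiler/jit:xla_jit_headers_lib"],
    )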
@@ -38,6 +38,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_context.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -149,6 +150,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
   xla::ExecutionOptions execution_options;
   *execution_options.mutable_shape_with_output_layout() =
       kernel->xla_output_shape;
+  *execution_options.mutable_debug_options() =
+      xla::legacy_flags::GetDebugOptionsFromFlags();
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
   VLOG(1) << "Executing XLA Computation...";
@@ -202,8 +205,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {

   // Apply variable updates, if any.
   VLOG(2) << "Applying variable updates";
-  for (int i = 0; i < kernel->variable_updates.size(); ++i) {
-    const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i];
+  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
     OP_REQUIRES(ctx,
                 write.input_index >= 0 && write.input_index < ctx->num_inputs(),
                 errors::Internal("Invalid input index for variable write."));
@@ -1,32 +1,20 @@
 licenses(["notice"])  # Apache 2.0

 package(
-    default_visibility = [
-        "//tensorflow/compiler/tf2xla:internal",
-    ],
+    default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
 )

 cc_library(
     name = "xla_ops",
-    srcs = [
-        "xla_ops.cc",
-    ],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
+    srcs = ["xla_ops.cc"],
+    deps = ["//tensorflow/core:framework"],
     alwayslink = 1,
 )

 cc_library(
     name = "parallel_check_op",
     srcs = ["parallel_check_op.cc"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
+    deps = ["//tensorflow/core:framework"],
     alwayslink = 1,
 )

@@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args,
       XlaCompiler::Argument& arg = (*args)[input_num];

       arg.name = variable_args[variable_id].name;
-      arg.kind = XlaCompiler::Argument::kVariable;
       if (variable_args[variable_id].present) {
         const Tensor& value = variable_args[variable_id].value;
+        arg.kind = XlaCompiler::Argument::kVariable;
         arg.type = value.dtype();
         arg.shape = value.shape();
         arg.initialized = true;
       } else {
         // The values of uninitialized variables are not passed as inputs, since
         // they are meaningless. However, it is legal to assign to a resource
         // variable for the first time inside the XLA computation, so we do permit
         // uninitialized variables.
         arg.kind = XlaCompiler::Argument::kUninitializedVariable;
+        arg.initialized = false;
         arg.type = DT_INVALID;
         arg.shape = TensorShape();
       }
@@ -137,7 +137,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
     done(result.status());
     return;
   }
-  const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie());
+  const void* src_ptr = result.ValueOrDie()->InternalData();
   void* dst_ptr = DMAHelper::base(cpu_tensor);
   size_t total_bytes = cpu_tensor->TotalBytes();
   memcpy(dst_ptr, src_ptr, total_bytes);
@@ -40,6 +40,7 @@ py_library(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform",
         "//tensorflow/python:random_seed",
         "//tensorflow/python:variables",
     ],
 )
@@ -323,7 +324,7 @@ tf_xla_py_test(

 tf_xla_py_test(
     name = "reverse_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["reverse_ops_test.py"],
     deps = [
         ":xla_test",
@@ -228,34 +228,40 @@ class SpaceToBatchNDTest(XLATestCase):
         outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
                  [[4, 41], [6, 61]]])

-  def testDirect(self):
+  def testDirect0(self):
     # Test with zero-size remaining dimension.
     self._testDirect(
         input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])

+  def testDirect1(self):
     # Test with zero-size blocked dimension.
     self._testDirect(
         input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])

+  def testDirect2(self):
     # Test with padding up from zero size.
     self._testDirect(
         input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])

+  def testDirect3(self):
     self._testDirect(
         input_shape=[3, 3, 4, 5, 2],
         block_shape=[3, 4, 2],
         paddings=[[1, 2], [0, 0], [3, 0]])

+  def testDirect4(self):
     self._testDirect(
         input_shape=[3, 3, 4, 5, 2],
         block_shape=[3, 4, 2, 2],
         paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])

+  def testDirect5(self):
     self._testDirect(
         input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
         block_shape=[1, 1, 3, 4, 2, 2],
         paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])

+  def testDirect6(self):
     self._testDirect(
         input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
         block_shape=[1, 1, 3, 4, 2, 2, 1],
@@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase):
       r0_bad = gen_data_flow_ops._tensor_array_read_v3(
           handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
       with self.assertRaisesOpError(
-          "TensorArray dtype is float but Op requested dtype double."):
+          "TensorArray dtype is float but op has dtype double."):
         r0_bad.eval()

       # Test reading from a different index than the one we wrote to
@@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase):
                           [2000.0, -2000.0]],
                          grad_vals[0])

-  # TODO(phawkins): implement TensorArrayClose
-  # def testCloseTensorArray(self):
-  #   with self.test_session() as session, self.test_scope():
-  #     ta = tensor_array_ops.TensorArray(
-  #         dtype=dtypes.float32, tensor_array_name="foo", size=3)
-  #     c1 = ta.close()
-  #     session.run(c1)
+  def testCloseTensorArray(self):
+    with self.test_session() as session, self.test_scope():
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32, tensor_array_name="foo", size=3)
+      c1 = ta.close()
+      session.run(c1)

   def testSizeTensorArray(self):
     with self.test_session(), self.test_scope():
@@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase):
       s = ta.size()
       self.assertAllEqual(3, s.eval())

-  # TODO(phawkins): implement TensorArrayClose
-  # def testWriteCloseTensorArray(self):
-  #   with self.test_session(), self.test_scope():
-  #     ta = tensor_array_ops.TensorArray(
-  #         dtype=dtypes.float32,
-  #         tensor_array_name="foo",
-  #         size=3,
-  #         infer_shape=False)
-  #     w0 = ta.write(0, [[4.0, 5.0]])
-  #     w1 = w0.write(1, [3.0])
-  #     w1.close().run()  # Expected to run without problems
+  def testWriteCloseTensorArray(self):
+    with self.test_session(), self.test_scope():
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32,
+          tensor_array_name="foo",
+          size=3,
+          infer_shape=False)
+      w0 = ta.write(0, [[4.0, 5.0]])
+      w1 = w0.write(1, [3.0])
+      w1.close().run()  # Expected to run without problems

   # TODO(phawkins): implement while loops.
   # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
@@ -42,6 +42,7 @@ cc_library(
     deps = [
         ":common",
         ":dump_graph",
+        ":functionalize_control_flow",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
@@ -152,7 +153,6 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
     ],
 )

@@ -165,13 +165,10 @@ cc_test(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
     ],
 )

@@ -203,6 +200,58 @@ cc_library(
     ],
 )

+cc_library(
+    name = "functionalize_control_flow",
+    srcs = ["functionalize_control_flow.cc"],
+    hdrs = ["functionalize_control_flow.h"],
+    deps = [
+        "//tensorflow/compiler/jit:graph_to_functiondef",
+        "//tensorflow/compiler/tf2xla:dump_graph",
+        "//tensorflow/compiler/tf2xla/ops:functional_ops",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+    ],
+)
+
+cc_test(
+    name = "functionalize_control_flow_test",
+    srcs = ["functionalize_control_flow_test.cc"],
+    deps = [
+        ":functionalize_control_flow",
+        ":test_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:ops",
+        "//tensorflow/cc:resource_variable_ops",
+        "//tensorflow/compiler/tf2xla/cc:functional_ops",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:resource_variable_ops_op_lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+cc_library(
+    name = "test_util",
+    testonly = 1,
+    srcs = ["test_util.cc"],
+    hdrs = ["test_util.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 # -----------------------------------------------------------------------------

 filegroup(
tensorflow/compiler/tf2xla/cc/BUILD (new file, 44 lines)
@@ -0,0 +1,44 @@
package(
    default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")

tf_gen_op_wrapper_cc(
    name = "functional_ops_gen",
    include_internal_ops = 1,
    out_ops_file = "ops/functional_ops",
    deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"],
)

cc_library(
    name = "functional_ops",
    srcs = ["ops/functional_ops.cc"],
    hdrs = ["ops/functional_ops.h"],
    deps = [
        "//tensorflow/cc:const_op",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/compiler/tf2xla/ops:functional_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/compiler/tf2xla/functionalize_control_flow.cc (new file, 566 lines)
@@ -0,0 +1,566 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

#include <algorithm>
#include <deque>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/control_flow.h"

namespace tensorflow {

namespace {

const char* const kArgOp = "_Arg";
const char* const kRetValOp = "_Retval";

// Information about a loop argument.
struct Arg {
  // Every loop argument has an Enter node.
  Node* enter;

  // Is the loop argument a loop-invariant value? Taken from the `is_constant`
  // attribute on the Enter node.
  bool is_loop_invariant;

  // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
  // arguments must have all of the following nodes:
  Node* merge = nullptr;
  Node* switch_node = nullptr;
  Node* next_iteration = nullptr;
  Node* exit = nullptr;
};

// Information about a loop frame.
struct Frame {
  string name;

  // Pointer to the parent frame. The root frame has a pointer to itself.
  Frame* parent = nullptr;
  int num_children = 0;

  // Arguments to this loop.
  std::vector<Arg> args;

  // The loop condition of the loop. There should be exactly one loop condition
  // in every loop.
  Node* loop_cond = nullptr;

  // Set of nodes that belong to the loop frame.
  std::unordered_set<Node*> nodes;
};

// Copies a subgraph from `graph` to `output` by performing a reverse DFS
// starting at nodes in vector `stack`.
// `node_map` is a vector indexed by source node ID to dest nodes.
// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
// before the traversal clients can cut the graph. Returns an error if the
// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
// cut the graph and prevent the traversal from escaping.
//
// `squash_src_outputs` contains a bool for each source node ID. If true, then
// the source output on that node will be replaced by zero when copied. This is
// used when replacing a Switch node with an _Arg node. The output we are
// taking from the Switch node was not necessarily the first output, but _Arg
// nodes only have one output. By adding the Switch node to `squash_src_outputs`
// we rewrite the src_output of the corresponding edge to be 0.
Status CopySubgraph(const Graph& graph, const Frame& frame,
                    std::vector<Node*> stack,
                    const std::vector<bool>& squash_src_outputs,
                    std::vector<Node*>* node_map, Graph* output) {
  std::vector<bool> visited(graph.num_node_ids(), false);
  while (!stack.empty()) {
    Node* n = stack.back();
    stack.pop_back();

    VLOG(3) << "Copying node " << n->name();

    if (visited[n->id()]) continue;
    visited[n->id()] = true;

    for (const Edge* e : n->in_edges()) {
      Node* src = e->src();
      if (frame.nodes.find(src) == frame.nodes.end()) {
        // We traversed out of the loop frame, without encountering a cut node.
        return errors::Internal("Graph traversal of loop frame ", frame.name,
                                " escaped frame at ", src->name(),
                                " without encountering an argument node.");
      }
      if ((*node_map)[src->id()] == nullptr) {
        (*node_map)[src->id()] = output->CopyNode(src);
        stack.push_back(src);
      }
      Node* src_copy = (*node_map)[e->src()->id()];
      int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output();
      Node* dst_copy = (*node_map)[e->dst()->id()];
      output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
    }
  }
  return Status::OK();
}

Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) {
  NodeDef arg_def;
  NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp);
  builder.Attr("T", type);
  builder.Attr("index", index);
  TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
  Status status;
  *arg_node = graph->AddNode(arg_def, &status);
  return status;
}

Status BuildRetvalNode(Graph* graph, DataType type, int index,
                       Node** retval_node) {
  NodeDef ret_def;
  ret_def.set_op(kRetValOp);
  ret_def.set_name(strings::StrCat("_Retval", index));
  AddNodeAttr("T", type, &ret_def);
  AddNodeAttr("index", index, &ret_def);
  Status status;
  *retval_node = graph->AddNode(ret_def, &status);
  return status;
}

// Builds a graph for the loop condition.
Status BuildLoopCondition(const Graph& graph, Frame* frame,
                          std::unique_ptr<Graph>* cond_output) {
  VLOG(2) << "Building loop condition for " << frame->name;
  *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
  Graph* output = cond_output->get();

  // Map from nodes in the original graph to the condition graph.
  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);

  // Build one _Arg node for each Enter node.
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];

    Node* arg_node;
    TF_RETURN_IF_ERROR(
        BuildArgNode(output, arg.enter->input_type(0), i, &arg_node));
    if (arg.is_loop_invariant) {
      node_map[arg.enter->id()] = arg_node;
    } else {
      node_map[arg.merge->id()] = arg_node;
    }
  }

  // Build a Retval node for the loop condition. The LoopCond nodes are always
  // boolean because of the type constraints on the LoopCond op.
  TF_RETURN_IF_ERROR(
      BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()]));

  // Performs a reverse DFS, copying nodes and edges to the output graph.
  // The _Arg and _Retval nodes were added unconditionally above, so we are
  // guaranteed to get the correct function signature.
  TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond},
                                  squash_src_outputs, &node_map, output));

  return Status::OK();
}

// Builds a graph for the loop body.
Status BuildLoopBody(const Graph& graph, Frame* frame,
                     DataTypeVector* arg_types,
                     std::unique_ptr<Graph>* body_output) {
  VLOG(2) << "Building loop body for " << frame->name;
  *body_output = xla::MakeUnique<Graph>(graph.op_registry());
  Graph* output = body_output->get();

  // Map from nodes in the original graph to the condition graph.
  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);

  // Build one _Arg node for each Enter node.
  std::vector<Node*> next_iterations;
  next_iterations.reserve(frame->args.size());
  arg_types->reserve(frame->args.size());
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];

    DataType dtype = arg.enter->input_type(0);
    arg_types->push_back(dtype);
    Node* arg_node;
    TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node));

    if (dtype == DT_RESOURCE) {
      // The convention of the XLA bridge is that resource variable arguments
      // are only inputs to the loop body and have no corresponding output.
      // TODO(b/37741920): change the convention so that DT_RESOURCE variables
      // are both inputs and outputs, and then remove this case.
      TF_RET_CHECK(arg.is_loop_invariant);
      node_map[arg.enter->id()] = arg_node;
    } else {
      Node* retval_node;
      TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node));

      if (arg.is_loop_invariant) {
        // Argument is loop-invariant. Forward it from the Arg to the Retval.
        node_map[arg.enter->id()] = arg_node;
        output->AddEdge(arg_node, 0, retval_node, 0);
      } else {
        // Argument is loop-varying.
        node_map[arg.switch_node->id()] = arg_node;
        // The Switch node has two outputs, but _Arg only has one. This tells
        // the CopySubgraph function to rewrite the output number of edges from
        // the _Arg node to be 0 rather than copying the output number from the
        // Switch node.
        squash_src_outputs[arg.switch_node->id()] = true;
        node_map[arg.next_iteration->id()] = retval_node;
        next_iterations.push_back(arg.next_iteration);
      }
    }
  }

  // Performs a reverse DFS, copying nodes and edges to the output graph.
  // The _Arg and _Retval nodes were added unconditionally above, so we are
  // guaranteed to get the correct function signature.
  TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations),
                                  squash_src_outputs, &node_map, output));

  return Status::OK();
}

Status FunctionalizeLoop(Graph* graph, Frame* frame,
                         FunctionLibraryDefinition* library) {
  VLOG(2) << "Frame " << frame->name << " before: "
          << dump_graph::DumpGraphToFile("functionalize_before", *graph);

  // Split loop-varying Enter nodes with multiple successors. If the same
  // Tensor is fed as input to multiple loop arguments, we may end up with a
  // shared Enter node. We clone Enter nodes with multiple successors to
  // maintain the invariant of a unique Enter node per argument of the final
  // loop.
  std::vector<Arg> args;
  for (const Arg& arg : frame->args) {
    if (arg.is_loop_invariant) {
      args.push_back(arg);
    } else {
      std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
                                     arg.enter->out_edges().end());
      for (int i = 0; i < edges.size(); ++i) {
        TF_RET_CHECK(!edges[i]->IsControlEdge());
        Arg new_arg;
        new_arg.is_loop_invariant = false;
        if (i == 0) {
          new_arg.enter = arg.enter;
        } else {
          new_arg.enter = graph->CopyNode(arg.enter);
          frame->nodes.insert(new_arg.enter);
          for (Edge const* e : arg.enter->in_edges()) {
            graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
                           e->IsControlEdge() ? Graph::kControlSlot : 0);
          }
          Node* dst = edges[i]->dst();
          int dst_input = edges[i]->dst_input();
          graph->RemoveEdge(edges[i]);
          graph->AddEdge(new_arg.enter, 0, dst, dst_input);
        }
        args.push_back(new_arg);
      }
    }
  }
  frame->args = std::move(args);

  // Order the arguments so that:
  // a) resource variables are last, and
  // b) sort lexicographically by name (for deterministic output).
  std::sort(frame->args.begin(), frame->args.end(),
            [](const Arg& a, const Arg& b) {
              bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
              bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
              return std::tie(a_is_resource, a.enter->name()) <
                     std::tie(b_is_resource, b.enter->name());
            });

  if (frame->loop_cond == nullptr) {
    return errors::InvalidArgument("Loop ", frame->name,
                                   " has no LoopCond node");
  }

  // Find the set of Switch nodes that are successors of the LoopCond.
  std::unordered_set<Node*> switches;
  for (const Edge* edge : frame->loop_cond->out_edges()) {
    if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
        edge->dst_input() == 1) {
      switches.insert(edge->dst());
    }
  }

  // For each non-constant argument, looks for the following pattern of nodes:
  // Enter ----> Merge --------> Switch --> Exit
  //               ^                ^
  //               |                |
  //         NextIteration      LoopCond
  //               ^                ^
  //               |                |
  //              ...              ...
  for (Arg& arg : frame->args) {
    if (!arg.is_loop_invariant) {
      // Follow the edge from the Enter to Merge.
      if (arg.enter->out_edges().size() != 1) {
        return errors::Internal("Enter node for loop-varying argument ",
                                arg.enter->name(),
                                " does not have exactly one successor");
      }
      const Edge* enter_merge = *arg.enter->out_edges().begin();
      arg.merge = enter_merge->dst();
      if (!IsMerge(arg.merge)) {
        return errors::InvalidArgument(
            "Successor of Enter node for loop-varying argument ",
            arg.merge->name(),
            " is not a Merge node; got: ", arg.merge->type_string());
      }

      // Find the NextIteration from the merge. There should be two inputs to
      // the Merge and the NextIteration should be the other input.
      if (arg.merge->input_types().size() != 2) {
        return errors::InvalidArgument(
            "Unexpected number of inputs to Merge node for loop-varying "
            "argument ",
            arg.merge->name(), "; expected 2, got ",
            arg.merge->input_types().size());
      }
      TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
                                               &arg.next_iteration));
      if (!IsNextIteration(arg.next_iteration)) {
        return errors::InvalidArgument(
            "Expected NextIteration node as input to Merge node; got node ",
            arg.next_iteration->name(), " with kind ",
            arg.next_iteration->type_string());
      }

      // Find the Switch successor of the Merge. There should be exactly one
      // Switch node that is a successor of both the Merge and the LoopCond.
      for (const Edge* edge : arg.merge->out_edges()) {
        if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
            switches.find(edge->dst()) != switches.end()) {
          if (arg.switch_node != nullptr) {
            return errors::InvalidArgument("Duplicate Switch successors to ",
                                           arg.merge->name());
          }
          arg.switch_node = edge->dst();
        }
      }
      if (arg.switch_node == nullptr) {
        return errors::InvalidArgument("Missing Switch successor to ",
                                       arg.merge->name());
      }

      // Find the Exit successor of the Switch.
      for (const Edge* edge : arg.switch_node->out_edges()) {
        if (edge->src_output() == 0 && IsExit(edge->dst())) {
          if (arg.exit != nullptr) {
            return errors::InvalidArgument("Duplicate Exit successors to ",
                                           arg.switch_node->name());
          }
          arg.exit = edge->dst();
        }
      }
      if (arg.exit == nullptr) {
return errors::InvalidArgument("Mising Exit successor to ",
|
||||
arg.switch_node->name());
|
||||
      }
    }
  }

  // Builds the condition and body functions.
  std::unique_ptr<Graph> cond_graph;
  TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
  DataTypeVector arg_types;
  std::unique_ptr<Graph> body_graph;
  TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));

  VLOG(2) << "Frame " << frame->name << " condition: "
          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
          << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);

  static std::atomic<int64> sequence_num(0LL);
  int64 id = ++sequence_num;
  NameAttrList cond_name;
  cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
  NameAttrList body_name;
  body_name.set_name(strings::StrCat("_functionalize_body_", id));
  FunctionDef cond_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
  FunctionDef body_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));

  TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
  TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));

  // Builds a While operator.
  NodeDef while_def;
  NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
  builder.Attr("T", arg_types);
  builder.Attr("cond", cond_name);
  builder.Attr("body", body_name);
  std::vector<NodeDefBuilder::NodeOut> inputs;
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];
    const Edge* in_edge;
    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    if (in_edge->IsControlEdge()) {
      builder.ControlInput(in_edge->src()->name());
    } else {
      inputs.push_back(NodeDefBuilder::NodeOut(
          in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
    }
  }
  builder.Input(inputs);
  TF_RETURN_IF_ERROR(builder.Finalize(&while_def));

  Status status;
  Node* while_node = graph->AddNode(while_def, &status);
  if (!status.ok()) {
    return status;
  }

  // Copies edges to the Enter nodes and from the Exit nodes onto the While.
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];
    const Edge* in_edge;
    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    if (in_edge->IsControlEdge()) {
      graph->AddControlEdge(in_edge->src(), while_node);
    } else {
      graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
    }

    if (!arg.is_loop_invariant) {
      std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
                                     arg.exit->out_edges().end());
      for (const Edge* edge : edges) {
        Node* dst = edge->dst();
        int dst_input = edge->dst_input();
        graph->RemoveEdge(edge);

        int src_output =
            dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
        graph->AddEdge(while_node, src_output, dst, dst_input);
      }
    }
  }

  // Remove the old nodes from the graph, and add the while node to the parent
  // frame.
  for (Node* node : frame->nodes) {
    graph->RemoveNode(node);
  }
  frame->parent->nodes.insert(while_node);

  VLOG(2) << "Frame " << frame->name << " after: "
          << dump_graph::DumpGraphToFile("functionalize_after", *graph);

  return Status::OK();
}

}  // namespace

// Transformation that converts Tensorflow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "FunctionalizeControlFlow: "
          << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
  // Note: BuildControlFlowInfo() requires that the graph's source node is
  // connected to all source nodes in the graph. Many graphs violate this
  // invariant.
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));

  // Builds Frames, indexed by name.
  std::unordered_map<string, Frame> frames;
  for (Node* node : graph->op_nodes()) {
    const ControlFlowInfo& cf = cf_info[node->id()];

    VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
            << " frame: " << (cf.frame ? cf.frame->name() : "---")
            << " parent_frame: "
            << (cf.parent_frame ? cf.parent_frame->name() : "---");
    TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);

    Frame& frame = frames[cf.frame_name];
    Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
    if (frame.parent == nullptr) {
      frame.parent = parent;
      frame.name = cf.frame_name;
      ++parent->num_children;
    } else if (frame.parent != parent) {
      return errors::InvalidArgument("Mismatched parent frames for ",
                                     cf.frame->id(), ": ", parent->name, " vs ",
                                     frame.parent->name);
    }

    if (IsEnter(node)) {
      Arg arg;
      arg.enter = node;
      TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
                                     &arg.is_loop_invariant));
      frame.args.push_back(arg);
    } else if (IsLoopCond(node)) {
      if (frame.loop_cond) {
        return errors::InvalidArgument(
            "Loop ", cf.frame_name,
            " has more than one LoopCond node: ", node->name(), " and ",
            frame.loop_cond->name());
      }
      frame.loop_cond = node;
    }
    frame.nodes.insert(node);
  }

  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<Frame*> worklist;
  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Eliminate loops from innermost to outermost.
  while (!worklist.empty()) {
    Frame* frame = worklist.front();
    worklist.pop_front();
    if (frame->parent == frame) {
      // Skip the root frame.
      continue;
    }

    TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));

    // If the parent has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }

  return Status::OK();
}

}  // namespace tensorflow
tensorflow/compiler/tf2xla/functionalize_control_flow.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Transformation that converts tf.while_loop() loops into functional While
// operators, suitable for XLA compilation.
// TODO(b/36470387): add support for conditionals.
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
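Based on the declaration above and the usage in the test file that follows, a hypothetical call site for the pass looks like (sketch only):

    // Rewrites raw Enter/Merge/Switch/Exit loops in `graph` into XlaWhile
    // nodes whose cond/body functions are added to `library`.
    Graph graph(OpRegistry::Global());
    // ... build a graph that contains tf.while_loop-style control flow ...
    FunctionLibraryDefinition library(OpRegistry::Global(), {});
    TF_CHECK_OK(FunctionalizeControlFlow(&graph, &library));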
tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc (new file, 647 lines)
@@ -0,0 +1,647 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {
namespace {

// Returns the names of the "cond" and "body" functions for the While node
// in a graph.
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
                            NameAttrList* body) {
  for (const NodeDef& node : graph.node()) {
    if (node.op() == "XlaWhile") {
      const NameAttrList* result;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
      *cond = *result;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
      *body = *result;
      return Status::OK();
    }
  }
  return errors::NotFound("No XlaWhile node found in graph");
}

// Graph:
// x = array_ops.placeholder(dtypes.int32)
// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow, OneLoopVar) {
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto enter =
        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                            std::initializer_list<Input>{enter, dummy});
    auto ten = ops::Const<int32>(
        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
        10);
    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    auto switch_ =
        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
                                    switch_.output_false);
    auto identity =
        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
    auto one = ops::Const<int32>(
        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    auto next_iteration =
        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);

    auto sink = ops::Identity(scope.WithOpName("sink"), exit);

    // Remove the dummy node and add the loop backedge.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));

  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  NameAttrList cond_fn, body_fn;
  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));

  // Outer graph
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto while_op =
        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
                      std::initializer_list<Input>{source}, cond_fn, body_fn);
    auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));
|
||||
TF_EXPECT_GRAPH_EQ(expected, graph_def);
|
||||
}
|
||||
|
||||
// Condition graph
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
|
||||
auto ten = ops::Const<int32>(
|
||||
scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
|
||||
auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
|
||||
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
|
||||
|
||||
GraphDef expected;
|
||||
TF_EXPECT_OK(scope.ToGraphDef(&expected));
|
||||
|
||||
InstantiationResultForTest result;
|
||||
TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
|
||||
|
||||
EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
|
||||
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
|
||||
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
|
||||
}
|
||||
|
||||
// Body graph.
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
|
||||
auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
|
||||
auto one = ops::Const<int32>(
|
||||
scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
|
||||
auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
|
||||
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
|
||||
|
||||
GraphDef expected;
|
||||
TF_EXPECT_OK(scope.ToGraphDef(&expected));
|
||||
|
||||
InstantiationResultForTest result;
|
||||
TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
|
||||
|
||||
EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
|
||||
EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
|
||||
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
|
||||
}
|
||||
}
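// For reference: the block above hand-builds the raw graph encoding of
// tf.while_loop. Enter moves a value into the loop frame, Merge joins the
// initial value with the NextIteration backedge, LoopCond feeds Switch, which
// routes each value either to Exit (loop finished) or to Identity and the
// body ops (next iteration). FunctionalizeControlFlow collapses that subgraph
// into one XlaWhile node whose "cond"/"body" attrs name library functions.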

// Graph:
// x = array_ops.placeholder(dtypes.int32)
// y = array_ops.placeholder(dtypes.int32)
// cond = lambda (i, j): i + 3 < 10
// body = lambda (i, j): (i < 10, j * 2)
// z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow, TwoLoopVars) {
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
    auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
    auto enter_x =
        ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
    auto enter_y =
        ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
    auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
                              std::initializer_list<Input>{enter_x, dummy});
    auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
                              std::initializer_list<Input>{enter_y, dummy});

    // Loop condition
    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
                                       .WithControlDependencies(merge_x.output),
                                   3);
    auto cond_add =
        ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
    auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
                                     .WithControlDependencies(merge_x.output),
                                 10);
    auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);

    auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
                                merge_x.output, loop_cond);
    auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
                                merge_y.output, loop_cond);

    auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
                                      switch_x.output_false);
    auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
                                      switch_y.output_false);

    auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
                                    switch_x.output_true);
    auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
                                    switch_y.output_true);

    auto one = ops::Const<int32>(
        scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
        1);
    auto two = ops::Const<int32>(
        scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
        2);

    auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
    auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
    auto next_iteration_x =
        ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
    auto next_iteration_y =
        ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);

    auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
    auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);

    // Remove the dummy node and add the loop backedges.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
                           1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));

  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  NameAttrList cond_fn, body_fn;
  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));

  // Outer graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
    auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
    auto while_op =
        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
    auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
    auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, graph_def);
  }

  // Condition graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
                                       .WithControlDependencies(arg0.output),
                                   3);
    auto cond_add =
        ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
    auto ten = ops::Const<int32>(
        scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output),
        10);
    auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    InstantiationResultForTest result;
    TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }

  // Body graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);

    auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
    auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);

    auto one = ops::Const<int32>(
        scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
        1);
    auto two = ops::Const<int32>(
        scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
        2);

    auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
    auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
    auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    InstantiationResultForTest result;
    TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }
}
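// For reference: with several loop variables, each variable has its own
// Enter/Merge/Switch/Exit/NextIteration chain, but every chain shares the one
// LoopCond. After functionalization the XlaWhile node carries the whole
// variable tuple, and cond/body take one argument per variable (the cond
// returning a single DT_BOOL).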

// Example with nesting, loop-invariant arguments, and resource variables.
//
// accum = resource_variable_ops.ResourceVariable(1)
// x = array_ops.placeholder(2, dtype=dtypes.int32)
// y = 3 + x
//
// def inner_body(j, k):
//   add = state_ops.assign_add(accum, k * j + x)
//   with ops.control_dependencies([add]):
//     return [j + 1, k]
//
// def body(i):
//   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
//                                   [1, y], name="inner")
//   with ops.control_dependencies(m):
//     return [i + 1]
//
// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow, Complex) {
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    auto y = ops::Add(scope.WithOpName("y"), x, three);

    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                TensorShape({}));

    // Outer loop
    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
    auto enter_i =
        ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
    auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
                              std::initializer_list<Input>{enter_i, dummy});
    auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
                                     .WithControlDependencies(merge_i.output),
                                 10);
    auto less_i =
        ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
    auto outer_loop_cond =
        ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
    auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
                                merge_i.output, outer_loop_cond);
    auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
                                      switch_i.output_false);
    auto identity_i =
        ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);

    auto enter_x_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_k_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));

    // Inner loop
    auto one_j = ops::Const<int32>(
        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
                                        one_j, "inner");
    auto enter_k =
        ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
                                 .WithControlDependencies(identity_i),
                             enter_k_outer, "inner");
    auto enter_x = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));

    auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
                              std::initializer_list<Input>{enter_j, dummy});
    auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
                              std::initializer_list<Input>{enter_k, dummy});

    auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
                                      .WithControlDependencies(merge_j.output),
                                  5);
    auto less_j =
        ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
    auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);

    auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
                                merge_j.output, loop_cond);
    auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
                                merge_k.output, loop_cond);
    auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
                                      switch_j.output_false);
    auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
                                      switch_k.output_false);
    auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
                                    switch_j.output_true);
    auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
                                    switch_k.output_true);

    // Variable update
    auto mul_jk =
        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    auto add_jkx =
        ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
    auto assign = ops::AssignAddVariableOp(
        scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);

    auto one =
        ops::Const<int32>(scope.WithOpName("outer/inner/One")
                              .WithControlDependencies(
                                  gtl::ArraySlice<Operation>{assign.operation}),
                          1);
    auto add_j =
        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

    auto next_iteration_j = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_j"), add_j);
    auto next_iteration_k = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_k"), identity_k);

    // Body and backedge for outer loop.
    auto one_outer = ops::Const<int32>(
        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    auto add_i =
        ops::Add(scope.WithOpName("outer/add")
                     .WithControlDependencies(gtl::ArraySlice<Operation>{
                         exit_j.output.op(), exit_k.output.op()}),
                 identity_i, one_outer);
    auto next_iteration_i =
        ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);

    auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);

    // Remove the dummy node and add the loop backedges.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
                           1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));

  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  NameAttrList outer_cond_fn, outer_body_fn;
  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));

  // Outer graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    auto y = ops::Add(scope.WithOpName("y"), x, three);

    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                TensorShape({}));

    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);

    auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
                                  std::initializer_list<Input>{zero, y, x, var},
                                  outer_cond_fn, outer_body_fn);
    auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, graph_def);
  }

  // Outer condition graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

    auto ten = ops::Const<int32>(
        scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
        10);
    auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    InstantiationResultForTest result;
    TF_EXPECT_OK(
        InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
              result.arg_types);
    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }

  // Outer body graph.
  NameAttrList inner_cond_fn, inner_body_fn;
  {
    InstantiationResultForTest result;
    TF_EXPECT_OK(
        InstantiateFunctionForTest(outer_body_fn.name(), library, &result));

    // Find the inner condition and body names.
    TF_EXPECT_OK(
        FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));

    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

    auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
    auto one_j = ops::Const<int32>(
        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    auto while_op =
        ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
                      std::initializer_list<Input>{one_j, arg1, arg2, arg3},
                      inner_cond_fn, inner_body_fn);

    auto one_outer = ops::Const<int32>(
        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    auto add_i =
        ops::Add(scope.WithOpName("outer/add")
                     .WithControlDependencies(gtl::ArraySlice<Operation>{
                         while_op[0].op(), while_op[1].op()}),
                 identity_i, one_outer);

    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
    auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
    auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
              result.arg_types);
    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }

  // Inner condition graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

    auto five = ops::Const<int32>(
        scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
    auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    InstantiationResultForTest result;
    TF_EXPECT_OK(
        InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
              result.arg_types);
    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }

  // Inner body graph.
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

    auto identity_j =
        ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
    auto identity_k =
        ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);

    auto mul_jk =
        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
    auto assign = ops::AssignAddVariableOp(
        scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);

    auto one =
        ops::Const<int32>(scope.WithOpName("outer/inner/One")
                              .WithControlDependencies(
                                  gtl::ArraySlice<Operation>{assign.operation}),
                          1);
    auto add_j =
        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
    auto retval1 =
        ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
    auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

    GraphDef expected;
    TF_EXPECT_OK(scope.ToGraphDef(&expected));

    InstantiationResultForTest result;
    TF_EXPECT_OK(
        InstantiateFunctionForTest(inner_body_fn.name(), library, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
              result.arg_types);
    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
  }
}

} // namespace
} // namespace tensorflow
@ -68,6 +68,7 @@ tf_kernel_library(
        "reduction_ops.h",
    ],
    deps = [
        ":while_op",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:literal_util",
@ -91,6 +92,21 @@ tf_kernel_library(
    ],
)

tf_kernel_library(
    name = "while_op",
    srcs = ["while_op.cc"],
    hdrs = ["while_op.h"],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/ops:functional_ops",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:tensorflow_opensource",
    ],
)

# Kernels that only work on CPU, because they use XLA custom calls.
# Only link this when using the CPU backend for XLA.
#
@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel {

    XlaContext& xc = XlaContext::Get(ctx);
    const XlaContext::Argument& arg = xc.args()[index_];
    if (arg.is_variable) {
    if (arg.is_resource) {
      XlaResource::Kind kind;
      switch (arg.kind) {
        case XlaCompiler::Argument::kVariable:
          kind = XlaResource::kVariable;
          break;
        case XlaCompiler::Argument::kTensorArray:
          kind = XlaResource::kTensorArray;
          break;
        default:
          CHECK(false);
      }

      // TODO(phawkins): this code assumes that variables do not alias.
      XlaVariable* var;
      OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
                                            arg.value.handle, &var));
      var->tensor_array_size = arg.tensor_array_size;
      ctx->SetVariableOutput(0, var);
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     xc.CreateResource(kind, index_, arg.name, arg.value.type,
                                       arg.value.handle, &resource));
      resource->tensor_array_size = arg.tensor_array_size;
      ctx->SetResourceOutput(0, resource);
    } else if (arg.value.is_constant) {
      ctx->SetConstantOutput(0, arg.value.constant_value);
    } else {
@ -127,8 +127,8 @@ void BatchToSpace(XlaOpKernelContext* ctx,
  std::vector<int64> end_indices = reshaped_permuted_shape;
  std::vector<int64> strides(input_rank, 1);
  for (int i = 0; i < block_rank; ++i) {
    int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
    int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
    int64 crop_start = crops.Get<int64>({i, 0});
    int64 crop_end = crops.Get<int64>({i, 1});
    OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    start_indices[1 + i] = crop_start;
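This hunk is the first of many below that migrate from the static xla::LiteralUtil helpers to member functions on xla::Literal itself; the pattern, shown on the accessors from the hunk above:

// Old: static helper taking the literal as an argument.
int64 crop_old = xla::LiteralUtil::Get<int64>(crops, {i, 0});
// New: accessor method on the literal object; same template type and index.
int64 crop_new = crops.Get<int64>({i, 0});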
@ -55,7 +55,7 @@ class BCastGradArgsOp : public XlaOpKernel {

    BCast::Vec vec;
    for (int64 i = 0; i < in_shape.num_elements(); ++i) {
      vec.push_back(xla::LiteralUtil::Get<int>(literal, {i}));
      vec.push_back(literal.Get<int>({i}));
    }
    shapes.push_back(vec);
  }
@ -52,7 +52,7 @@ class ConcatBaseOp : public XlaOpKernel {
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
    // TODO(annarev): add a helper to support int64 input.
    const int32 concat_dim = xla::LiteralUtil::Get<int>(literal, {});
    const int32 concat_dim = literal.Get<int>({});

    std::vector<xla::ComputationDataHandle> values;
    std::vector<TensorShape> shapes;
@ -163,7 +163,7 @@ class ConcatOffsetOp : public XlaOpKernel {

    xla::Literal concat_dim_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
    const int64 cdim = xla::LiteralUtil::Get<int>(concat_dim_literal, {});
    const int64 cdim = concat_dim_literal.Get<int>({});

    VLOG(1) << "ConcatOffset " << cdim << "," << dims;
    int32 axis = cdim < 0 ? cdim + dims : cdim;
@ -185,12 +185,10 @@ class ConcatOffsetOp : public XlaOpKernel {
      for (int64 j = 0; j < dims; ++j) {
        if (j == axis) {
          out_vec(j) = offset;
          offset += xla::LiteralUtil::Get<int>(inp_literal, {j});
          offset += inp_literal.Get<int>({j});
        } else {
          const int32 inp0_element =
              xla::LiteralUtil::Get<int>(inp0_literal, {j});
          const int32 inp_element =
              xla::LiteralUtil::Get<int>(inp_literal, {j});
          const int32 inp0_element = inp0_literal.Get<int>({j});
          const int32 inp_element = inp_literal.Get<int>({j});
          OP_REQUIRES(
              ctx, (inp0_element == inp_element),
              errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",
@ -103,8 +103,7 @@ class DynamicStitchOp : public XlaOpKernel {
    int max_index = -1;
    for (int input_num = 0; input_num < indices.size(); input_num++) {
      for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
        max_index = std::max(
            max_index, xla::LiteralUtil::Get<int>(indices[input_num], {i}));
        max_index = std::max(max_index, indices[input_num].Get<int>({i}));
      }
    }
    int number_of_indices = max_index + 1;
@ -118,7 +117,7 @@ class DynamicStitchOp : public XlaOpKernel {
    int index_used_count = 0;
    for (int input_num = 0; input_num < indices.size(); input_num++) {
      for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
        int index = xla::LiteralUtil::Get<int>(indices[input_num], {i});
        int index = indices[input_num].Get<int>({i});
        src_input_vector[index] = input_num;
        src_slice_vector[index] = i;
        if (!src_index_used[index]) {
@ -52,7 +52,7 @@ class FillOp : public XlaOpKernel {
    std::vector<int64> broadcast;
    broadcast.reserve(dims_literal.shape().dimensions(0));
    for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
      broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
      broadcast.push_back(dims_literal.Get<int>({i}));
    }
    // Look up the value input, reshaping to a scalar if it was a
    // 'legacy' scalar (secretly a vector).
@ -66,10 +66,10 @@ class GatherOp : public XlaOpKernel {
    std::vector<xla::ComputationDataHandle> args;
    args.push_back(tc.GetOrCreateRuntimeContextParameter());
    args.push_back(b.ConstantLiteral(
        *xla::LiteralUtil::CreateR0<int64>(indices_shape.num_elements())));
        *xla::Literal::CreateR0<int64>(indices_shape.num_elements())));
    args.push_back(b.ConstantLiteral(
        *xla::LiteralUtil::CreateR0<int64>(params_shape.dim_size(0))));
    args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int64>(
        *xla::Literal::CreateR0<int64>(params_shape.dim_size(0))));
    args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int64>(
        params_shape.num_elements() / params_shape.dim_size(0))));
    args.push_back(ctx->Input(0));
    args.push_back(ctx->Input(1));
@ -69,7 +69,7 @@ class ArgMaxOp : public XlaOpKernel {
    // XLA op would have the same requirement.
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
    const int32 dim = xla::LiteralUtil::Get<int32>(literal, {});
    const int32 dim = literal.Get<int32>({});
    OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
    OP_REQUIRES(
        ctx, dim < input_shape.dims(),
@ -97,14 +97,13 @@ class ArgMaxOp : public XlaOpKernel {
    std::vector<xla::ComputationDataHandle> args;
    args.push_back(ctx->Input(0));
    args.push_back(b.ConstantLiteral(
        *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
        *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
    if (input_shape.dims() > 1) {
      // Don't bother passing the output shape and dim for the 1d case, since
      // the shape is always a scalar and the dim is always 0.
      args.push_back(b.ConstantLiteral(
          *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
      args.push_back(
          b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int32>(dim)));
          *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
      args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
    }

    xla::Shape xla_shape =
@ -60,8 +60,8 @@ class PadOp : public XlaOpKernel {
    xla::PaddingConfig config;
    for (int i = 0; i < fixed_dims; ++i) {
      auto* dim = config.add_dimensions();
      int before = xla::LiteralUtil::Get<int32>(pad_literal, {i, 0});
      int after = xla::LiteralUtil::Get<int32>(pad_literal, {i, 1});
      int before = pad_literal.Get<int32>({i, 0});
      int after = pad_literal.Get<int32>({i, 1});
      OP_REQUIRES(ctx, before >= 0 && after >= 0,
                  errors::InvalidArgument("Paddings must be non-negative: ",
                                          before, " ", after));
@ -63,7 +63,7 @@ class MinOp : public XlaReductionOp {
      xla::ComputationBuilder* builder) override {
    xla::PrimitiveType type;
    TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
    return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
    return builder->ConstantLiteral(xla::Literal::MaxValue(type));
  }

  void BuildReducer(xla::ComputationBuilder* builder,
@ -83,7 +83,7 @@ class MaxOp : public XlaReductionOp {
      xla::ComputationBuilder* builder) override {
    xla::PrimitiveType type;
    TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
    return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type));
    return builder->ConstantLiteral(xla::Literal::MinValue(type));
  }

  void BuildReducer(xla::ComputationBuilder* builder,

@ -66,13 +66,13 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
                          1, {axes_tensor_shape.num_elements()}, &axes_literal));

  VLOG(1) << "data shape: " << data_shape.DebugString();
  VLOG(1) << "axes      : " << xla::LiteralUtil::ToString(axes_literal);
  VLOG(1) << "axes      : " << axes_literal.ToString();

  gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
  std::vector<int64> xla_axes;
  int64 num_elements_reduced = 1LL;
  for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
    int32 index = xla::LiteralUtil::Get<int>(axes_literal, {i});
    int32 index = axes_literal.Get<int>({i});
    OP_REQUIRES(ctx,
                !(index < -data_shape.dims() || index >= data_shape.dims()),
                errors::InvalidArgument("Invalid reduction dimension (", index,
@ -50,7 +50,7 @@ class ReshapeOp : public XlaOpKernel {
    int64 product = 1;
    int unknown_index = -1;
    for (int d = 0; d < num_dims; ++d) {
      const int32 size = xla::LiteralUtil::Get<int>(literal, {d});
      const int32 size = literal.Get<int>({d});
      if (size == -1) {
        OP_REQUIRES(
            ctx, unknown_index == -1,
@ -32,7 +32,7 @@ template <typename T>
Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
  *value = xla::LiteralUtil::Get<T>(literal, {});
  *value = literal.Get<T>({});
  return Status::OK();
}

@ -41,10 +41,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
  TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
  switch (literal.shape().element_type()) {
    case xla::S32:
      *value = xla::LiteralUtil::Get<int32>(literal, {});
      *value = literal.Get<int32>({});
      break;
    case xla::S64:
      *value = xla::LiteralUtil::Get<int64>(literal, {});
      *value = literal.Get<int64>({});
      break;
    default:
      return errors::InvalidArgument("Invalid argument type for argument",
@ -58,9 +58,9 @@ template <typename T>
Status CreateRangeTensor(const xla::Literal& start_literal,
                         const xla::Literal& limit_literal,
                         const xla::Literal& delta_literal, Tensor* output) {
  T start = xla::LiteralUtil::Get<T>(start_literal, {});
  T limit = xla::LiteralUtil::Get<T>(limit_literal, {});
  T delta = xla::LiteralUtil::Get<T>(delta_literal, {});
  T start = start_literal.Get<T>({});
  T limit = limit_literal.Get<T>({});
  T delta = delta_literal.Get<T>({});

  if (delta == 0) {
    return errors::InvalidArgument("Requires delta != 0: ", delta);
@ -56,8 +56,8 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
    int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
    int64 pad_start = paddings.Get<int64>({i, 0});
    int64 pad_end = paddings.Get<int64>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    dim->set_edge_padding_low(pad_start);
@ -39,7 +39,7 @@ class SplitOp : public XlaOpKernel {

    int32 split_dim;
    if (index_shape.dims() == 0) {
      split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
      split_dim = literal_index.Get<int>({});
    } else {
      OP_REQUIRES(
          ctx, index_shape.dims() == 1,
@ -49,7 +49,7 @@ class SplitOp : public XlaOpKernel {
          ctx, index_shape.dim_size(0) == 1,
          errors::InvalidArgument("split_index input to Split Op must be a "
                                  "scalar or a vector with 1 element"));
      split_dim = xla::LiteralUtil::Get<int>(literal_index, {0});
      split_dim = literal_index.Get<int>({0});
    }
    const int32 num_split = num_outputs();
    const TensorShape input_shape = ctx->InputShape(1);
@ -115,7 +115,7 @@ class SplitVOp : public XlaOpKernel {
    OP_REQUIRES(ctx, index_shape.dims() == 0,
                errors::InvalidArgument("split_dim input to Split Op must be a "
                                        "scalar"));
    split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
    split_dim = literal_index.Get<int>({});

    xla::ComputationDataHandle input = ctx->Input(0);
    const TensorShape input_shape = ctx->InputShape(0);
@ -152,7 +152,7 @@ class SplitVOp : public XlaOpKernel {

    for (int i = 0; i < num_split; ++i) {
      int slice_size;
      slice_size = xla::LiteralUtil::Get<int>(split_size_literal, {i});
      slice_size = split_size_literal.Get<int>({i});
      if (slice_size == -1) {
        OP_REQUIRES(
            ctx, neg_one_dim == -1,
@ -41,32 +41,36 @@ namespace {
// Since the element shape is not always provided to the TensorArrayV3 operator,
// we must support lazy initialization of the TensorArray at the time of the
// first write.
// If a TensorArray `var` has not been initialized, constructs storage for the
// TensorArray with elements of `elem_shape`. For both initialized and
// If a TensorArray `resource` has not been initialized, constructs storage for
// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
                                  XlaVariable* var, DataType dtype,
                                  XlaResource* resource, DataType dtype,
                                  const TensorShape& elem_shape) {
  if (var->type != dtype) {
  if (resource->kind != XlaResource::kTensorArray) {
    return errors::InvalidArgument("Unexpected non-TensorArray resource");
  }

  if (resource->type != dtype) {
    return errors::InvalidArgument(
        "TensorArray dtype is ", DataTypeString(var->type),
        "TensorArray dtype is ", DataTypeString(resource->type),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  TF_RET_CHECK(var->tensor_array_size >= 0)
      << var->name << " size " << var->tensor_array_size;
  TF_RET_CHECK(resource->tensor_array_size >= 0)
      << resource->name << " size " << resource->tensor_array_size;
  TensorShape ta_shape;
  ta_shape.AddDim(var->tensor_array_size);
  ta_shape.AddDim(resource->tensor_array_size);
  ta_shape.AppendShape(elem_shape);

  if (var->value.handle() == 0) {
  if (resource->value.handle() == 0) {
    // TensorArray has not been initialized.
    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
    var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
    resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
  } else {
    // Checks the elem_shape matches the TensorArray shape.
    auto shape_or_status = builder->GetShape(var->value);
    auto shape_or_status = builder->GetShape(resource->value);
    if (!shape_or_status.ok()) {
      return shape_or_status.status();
    }
@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
  return Status::OK();
}

// Checks that the TensorArray 'resource' has been initialized, and has type
// 'dtype'.
Status CheckTensorArrayIsInitialized(const string& op_name,
                                     const XlaResource* resource,
                                     DataType dtype) {
  if (resource->kind != XlaResource::kTensorArray) {
    return errors::InvalidArgument(
        "Unexpected non-TensorArray resource passed "
        "to ",
        op_name);
  }
  if (resource->value.handle() == 0) {
    return errors::InvalidArgument("Uninitialized TensorArray passed to ",
                                   op_name);
  }
  if (resource->type != dtype) {
    return errors::InvalidArgument(
        "TensorArray dtype is ", DataTypeString(resource->type),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  return Status::OK();
}

Status GetTensorArrayShape(const XlaResource* resource,
                           xla::ComputationBuilder* builder,
                           TensorShape* shape) {
  auto shape_or_status = builder->GetShape(resource->value);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
  if (shape->dims() < 1) {
    return errors::InvalidArgument("TensorArray rank must be >= 1");
  }
  return Status::OK();
}
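// For reference: the read-style kernels below now share one pattern -- fetch
// the XlaResource input, validate it with CheckTensorArrayIsInitialized(),
// recover its shape with GetTensorArrayShape(), then read or update
// resource->value directly instead of going through the old variable
// read/assign helpers.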

// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel {
                errors::InvalidArgument("TensorArray size must be >= 0"));

    xla::ComputationBuilder* b = ctx->builder();
    b->set_die_immediately_on_error(true);

    // Initializes the TensorArray value if we know the element shape.
    // Otherwise, defer initialization to the first write.
@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel {
    }

    XlaContext& xc = XlaContext::Get(ctx);
    XlaVariable* var;
    XlaResource* var;
    string name = strings::StrCat("TensorArray: ", tensor_array_name_);
    OP_REQUIRES_OK(ctx,
                   xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
                               dtype_, value, &var));
    var->tensor_array_size = size;
    ctx->SetVariableOutput(0, var);
    ctx->SetResourceOutput(0, var);
    ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
  }

@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {

    // Initializes the TensorArray, if the element shape was not known at
    // construction time.
    XlaVariable* var;
    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));

    xla::ComputationDataHandle ta = var->value;
    xla::ComputationDataHandle ta = resource->value;
    xla::ComputationDataHandle index = ctx->Input(1);
    xla::ComputationDataHandle value = ctx->Input(2);

@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
    xla::ComputationDataHandle written =
        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
    resource->value = written;
    ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
  }

@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel {
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType ta_type;
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
    OP_REQUIRES(ctx, ta_type == dtype_,
                errors::InvalidArgument(
                    "TensorArray dtype is ", DataTypeString(ta_type),
                    " but Op requested dtype ", DataTypeString(dtype_), "."));
    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
                errors::InvalidArgument("TensorArray rank must be >= 1"));

    xla::ComputationBuilder* b = ctx->builder();

    xla::ComputationDataHandle ta;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    xla::ComputationDataHandle ta = resource->value;
    xla::ComputationDataHandle index = ctx->Input(1);

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel {
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType ta_type;
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
    OP_REQUIRES(ctx, ta_type == dtype_,
                errors::InvalidArgument("TensorArray type mismatch"));
    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
                errors::InvalidArgument("TensorArray rank must be >= 1"));
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    const TensorShape indices_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
    const int num_indices = indices_shape.dim_size(0);
    auto indices = ctx->Input(1);

    xla::ComputationBuilder* b = ctx->builder();

    xla::ComputationDataHandle ta;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
    xla::ComputationDataHandle ta = resource->value;

    // For each index in `indices`, add the corresponding slice to `slices`.
    std::vector<xla::ComputationDataHandle> slices(num_indices);
@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel {

    const TensorShape value_shape = ctx->InputShape(2);

    XlaVariable* var;
    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    TensorShape elem_shape = value_shape;
    elem_shape.RemoveDim(0);
    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));

    const TensorShape indices_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
    const int num_indices = indices_shape.dim_size(0);
    const xla::ComputationDataHandle indices = ctx->Input(1);

    xla::ComputationDataHandle ta = var->value;
    xla::ComputationDataHandle ta = resource->value;
    const xla::ComputationDataHandle value = ctx->Input(2);

    auto slice_dims = value_shape.dim_sizes();
@ -357,7 +397,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
      ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
    resource->value = ta;
    ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
  }

@ -376,18 +416,17 @@ class TensorArrayConcatOp : public XlaOpKernel {
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType ta_type;
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
    OP_REQUIRES(ctx, ta_type == dtype_,
                errors::InvalidArgument("TensorArray type mismatch"));
    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
                errors::InvalidArgument("TensorArray rank must be >= 1"));

    xla::ComputationBuilder* b = ctx->builder();

    xla::ComputationDataHandle ta;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES_OK(ctx,
                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    xla::ComputationDataHandle ta = resource->value;

    auto ta_dims = ta_shape.dim_sizes();
    std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@ -438,19 +477,20 @@ class TensorArraySplitOp : public XlaOpKernel {
    elem_shape.set_dim(0, length);

    xla::ComputationBuilder* b = ctx->builder();
    XlaVariable* var;
    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
    xla::ComputationDataHandle ta = var->value;
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    OP_REQUIRES_OK(ctx,
                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
    xla::ComputationDataHandle ta = resource->value;

    TensorShape ta_shape;
    ta_shape.AddDim(var->tensor_array_size);
    ta_shape.AddDim(resource->tensor_array_size);
    ta_shape.AppendShape(elem_shape);

    OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
    OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
                errors::InvalidArgument(
                    "TensorArray's size is not equal to the size of lengths (",
                    lengths.size(), " vs. ", var->tensor_array_size, ")"));
                    lengths.size(), " vs. ", resource->tensor_array_size, ")"));

    const xla::ComputationDataHandle value = ctx->Input(1);

@ -459,8 +499,7 @@ class TensorArraySplitOp : public XlaOpKernel {
                    value_shape.DebugString(), " vs. ",
                    ta_shape.DebugString()));

    ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
    resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));

    ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
  }
@ -478,8 +517,8 @@ class TensorArraySizeOp : public XlaOpKernel {
  explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    XlaVariable* var;
    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
    XlaResource* var;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
    Tensor size_tensor(DT_INT32, {});
    size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
    ctx->SetConstantOutput(0, size_tensor);
@ -500,31 +539,31 @@ class TensorArrayGradOp : public XlaOpKernel {
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaVariable* var;
    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    DataType ta_type;
    OP_REQUIRES_OK(
        ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type));
    TensorShape ta_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
                errors::InvalidArgument("TensorArray rank must be >= 1"));
    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));

    // Finds or looks up the corresponding gradient TensorArray, which stores
    // gradients computed during backpropagation.
    XlaVariable*& gradient = var->tensor_array_gradient[source_];
    XlaResource*& gradient = resource->tensor_array_gradient[source_];
    if (!gradient) {
      xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
      xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
      xla::ComputationDataHandle value =
          b->Broadcast(zero, ta_shape.dim_sizes());

      XlaContext& xc = XlaContext::Get(ctx);
      string name = strings::StrCat("TensorArrayGrad: ", var->name);
      OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
                                            value, &gradient));
      gradient->tensor_array_size = var->tensor_array_size;
      string name = strings::StrCat("TensorArrayGrad: ", resource->name);
      OP_REQUIRES_OK(
          ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
                                 resource->type, value, &gradient));
      gradient->tensor_array_size = resource->tensor_array_size;
    }

    ctx->SetVariableOutput(0, gradient);
    ctx->SetResourceOutput(0, gradient);
    ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
  }

@ -536,5 +575,19 @@ class TensorArrayGradOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);

class TensorArrayCloseOp : public XlaOpKernel {
 public:
  explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Do nothing; XLA handles resource management.
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp);
};

REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp);

} // anonymous namespace
} // namespace tensorflow
|
||||
|
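The gradient lookup above relies on a find-or-create idiom: indexing `tensor_array_gradient` with `operator[]` value-initializes the slot to null, so the op can create the gradient resource exactly once per `source_` string. A minimal standalone C++ sketch of that idiom (toy `Array`/`Gradient` types, not the TensorFlow classes):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Gradient {           // stands in for the gradient resource
  explicit Gradient(int size) : size(size) {}
  int size;
};

struct Array {              // stands in for the forward resource
  int size = 8;
  std::unordered_map<std::string, Gradient*> gradients;
  std::vector<std::unique_ptr<Gradient>> owned;  // ownership kept separately
};

Gradient* FindOrCreateGradient(Array* arr, const std::string& source) {
  Gradient*& slot = arr->gradients[source];  // nullptr on first access
  if (slot == nullptr) {
    arr->owned.push_back(std::make_unique<Gradient>(arr->size));
    slot = arr->owned.back().get();
  }
  return slot;  // every later call with the same 'source' returns this one
}

int main() {
  Array arr;
  Gradient* g1 = FindOrCreateGradient(&arr, "grad_a");
  Gradient* g2 = FindOrCreateGradient(&arr, "grad_a");
  std::cout << (g1 == g2) << "\n";  // prints 1: one gradient per source
}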
@ -68,7 +68,7 @@ class TileOp : public XlaOpKernel {
    bool all_multiples_are_one = true;
    bool one_dimension_is_broadcasted_without_multiple = true;
    for (int i = 0; i < input_dims; ++i) {
      int multiple = xla::LiteralUtil::Get<int>(literal, {i});
      int multiple = literal.Get<int>({i});
      OP_REQUIRES(ctx, multiple >= 0,
                  errors::InvalidArgument("Expected multiples[", i,
                                          "] >= 0, but got ", multiple));
@ -44,6 +44,7 @@ namespace {
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
XLAJIT_MAKE_UNARY(Cos, b->Cos(x));
XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
265  tensorflow/compiler/tf2xla/kernels/while_op.cc  Normal file
@ -0,0 +1,265 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/while_op.h"

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace {

// Builds XlaCompiler argument descriptions `args` from `ctx`.
Status MakeXlaCompilerArgumentsFromInputs(
    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
    bool* has_uninitialized_vars) {
  VLOG(2) << "Num inputs " << ctx->num_inputs();
  args->resize(ctx->num_inputs());
  *has_uninitialized_vars = false;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    VLOG(2) << " Input " << i
            << " type: " << DataTypeString(ctx->input_type(i))
            << " shape: " << ctx->InputShape(i).DebugString();
    XlaCompiler::Argument& arg = (*args)[i];
    DataType type = ctx->input_type(i);
    // When reading a resource input, use the type and shape of the resource's
    // current value.
    if (type == DT_RESOURCE) {
      XlaResource* resource;
      TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));

      arg.initialized = resource->value.handle() > 0;
      switch (resource->kind) {
        case XlaResource::kVariable:
          arg.kind = XlaCompiler::Argument::kVariable;
          break;
        case XlaResource::kTensorArray:
          arg.kind = XlaCompiler::Argument::kTensorArray;
          break;
        case XlaResource::kInvalid:
          CHECK(false);
      }
      arg.type = resource->type;
      if (arg.initialized) {
        auto shape = ctx->builder()->GetShape(resource->value);
        TF_RETURN_IF_ERROR(shape.status());
        arg.shape = XLAShapeToTensorShape(*shape.ValueOrDie());
      } else {
        *has_uninitialized_vars = true;
      }
      arg.tensor_array_size = resource->tensor_array_size;
      arg.name = resource->name;
      // TODO(phawkins): propagate TensorArray gradients into loops.
      VLOG(2) << " resource " << resource->name
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.shape.DebugString()
              << " initialized: " << arg.initialized;

    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = ctx->input_type(i);
      arg.shape = ctx->InputShape(i);
    }
  }
  return Status::OK();
}

}  // anonymous namespace

XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  const NameAttrList* name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
  cond_name_attr_ = *name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
  body_name_attr_ = *name_attr;
}

void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
                          ctx, &arguments, &has_uninitialized_vars));

  const bool use_tuple_arg = (arguments.size() != 1);

  xla::ComputationBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  VLOG(1) << "Compiling body";

  // All resources that are inputs to the loop's body must also be
  // present as loop body outputs; the signature of the loop's input and
  // output must match. We ensure this by asking the compiler to include the
  // current values of all resources, even if they haven't been updated by the
  // computation.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = use_tuple_arg;
  body_options.return_updated_values_for_all_resources = true;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use a static shape for parameters to an XLA compilation. However,
  // we may not know the shape of a TensorArray if it is first written inside
  // the loop. Ideally we would require the user to provide a static shape,
  // but this is not always easy.
  // So if uninitialized resources are used by the loop body, we compile the
  // body function twice:
  // 1) once with uninitialized resource inputs. We discard the computation
  //    but we assume resource shapes reach a fixpoint after one iteration.
  //    So we can use the output shapes of the resources as the "true" shapes.
  // 2) again with the "correct" input shapes determined by (1).
  if (has_uninitialized_vars) {
    // Initializes any uninitialized resources with zero values of the
    // shape determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        arg.initialized = true;
        arg.shape = update.shape;

        XlaResource* resource;
        OP_REQUIRES_OK(ctx,
                       ctx->GetResourceInput(update.input_index, &resource));

        xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type);
        resource->value = builder->Broadcast(zero, update.shape.dim_sizes());
      }
    }
    // Recompile the body with the "correct" shapes.
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                  arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = use_tuple_arg;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  xla::Shape body_input_shape, cond_input_shape;
  if (use_tuple_arg) {
    body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
    cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
  } else {
    CHECK(!body.xla_input_shapes.empty());
    body_input_shape = body.xla_input_shapes[0];
    CHECK(!cond.xla_input_shapes.empty());
    cond_input_shape = cond.xla_input_shapes[0];
  }

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));
  OP_REQUIRES(
      ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape),
      errors::InvalidArgument(
          "Input and output shapes of loop body do not match: ",
          xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
          xla::ShapeUtil::HumanString(body.xla_output_shape)));

  xla::ComputationDataHandle data;

  int num_inputs = body.input_mapping.size();

  std::vector<xla::ComputationDataHandle> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      inputs[i] = resource->value;
    } else {
      inputs[i] = ctx->Input(i);
    }
  }

  xla::ComputationDataHandle init;
  if (use_tuple_arg) {
    init = builder->Tuple(inputs);
  } else {
    init = inputs[0];
  }

  VLOG(1) << "Building while loop";

  xla::ComputationDataHandle while_result =
      builder->While(*cond.computation, *body.computation, init);

  auto get_loop_output = [&](int i) {
    if (use_tuple_arg) {
      return builder->GetTupleElement(while_result, i);
    } else {
      return while_result;
    }
  };

  // Sets non-variable outputs.
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      ctx->SetOutput(body.input_mapping[i], get_loop_output(i));
    }
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
      int pos = body.outputs.size() + i;
      resource->value = get_loop_output(pos);
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}

REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);

}  // namespace tensorflow
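The comment block in XlaWhileOp::Compile above describes a two-pass compilation: compile once with unknown resource shapes, adopt the shapes the body produced (assuming a fixpoint after one iteration), then recompile. A toy standalone sketch of that scheme, with a stand-in Compile function rather than the real XlaCompiler API:

#include <iostream>
#include <vector>

struct Arg {
  bool initialized = false;
  std::vector<int> shape;  // empty until known
};

struct Result {
  std::vector<std::vector<int>> output_shapes;
};

// Stand-in "compiler": every output takes the input shape, or {2, 3} if the
// input shape is still unknown (mimicking a first write inside the loop).
Result Compile(const std::vector<Arg>& args) {
  Result r;
  for (const Arg& a : args) {
    r.output_shapes.push_back(a.initialized ? a.shape
                                            : std::vector<int>{2, 3});
  }
  return r;
}

int main() {
  std::vector<Arg> args(2);
  args[0].initialized = true;   // initialized resource
  args[0].shape = {4};
  // args[1] stays uninitialized: its shape is discovered by pass 1.

  Result first = Compile(args);          // pass 1: discover shapes
  for (size_t i = 0; i < args.size(); ++i) {
    if (!args[i].initialized) {          // assume shapes reach a fixpoint
      args[i].initialized = true;        // after one iteration
      args[i].shape = first.output_shapes[i];
    }
  }
  Result second = Compile(args);         // pass 2: shapes are now fixed

  std::cout << second.output_shapes[1].size() << "\n";  // prints 2 ({2, 3})
}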
65  tensorflow/compiler/tf2xla/kernels/while_op.h  Normal file
@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/attr_value.pb.h"

namespace tensorflow {

// This TensorFlow op provides a functional iteration primitive.
//
// The inputs and outputs of the loop body must agree on the number, types, and
// shapes of the Tensors carried around the loop body.
//
// Computations in while loops may read from and write to resource variables.
// Resource variables may be passed as arguments to a function's body and
// condition functions. The XlaCompiler converts resource variable arguments
// into parameters to the XLA computation and moves them to the end of the
// parameter list, and by using the `return_updated_values_for_all_variables`
// option we ensure that all variables that appear in the input also appear at
// the end of the body's output. This ensures the loop body's input and output
// signatures match.
//
// It is the user's responsibility to ensure that each non-variable _Arg
// matches the corresponding _Retval.
//
// For example, suppose we have a loop body with arguments:
//    DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT
// and return values
//    DT_INT32, DT_FLOAT
// It is an error for the body to return DT_RESOURCE values.
//
// The body will be lowered into an XLA computation that takes and returns a
// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at
// the end of both the loop body's input and output argument lists.
class XlaWhileOp : public XlaOpKernel {
 public:
  explicit XlaWhileOp(OpKernelConstruction* ctx);

  void Compile(XlaOpKernelContext* ctx) override;

 private:
  NameAttrList cond_name_attr_;
  NameAttrList body_name_attr_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_
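The header comment states the core invariant: the loop's carried input signature must equal the body's output signature. A small self-contained C++ illustration (toy types, not XLA) of enforcing that check before running a functional while loop:

#include <cassert>
#include <functional>
#include <iostream>
#include <vector>

enum class Type { I32, F32, PRED };

// Runs a functional while loop only if the carried signatures line up,
// mirroring the OP_REQUIRES shape checks in XlaWhileOp::Compile.
int RunWhile(const std::function<bool(int)>& cond,
             const std::function<int(int)>& body, int init,
             const std::vector<Type>& in_sig,
             const std::vector<Type>& out_sig) {
  assert(in_sig == out_sig && "loop body input/output signatures must match");
  int v = init;
  while (cond(v)) v = body(v);  // output = input; while (Cond) { Body }
  return v;
}

int main() {
  std::vector<Type> sig = {Type::I32};
  int result = RunWhile([](int v) { return v < 10; },
                        [](int v) { return v + 3; }, 0, sig, sig);
  std::cout << result << "\n";  // prints 12
}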
@ -27,13 +27,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
      host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape()));

  xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal);
  literal->Reserve(host_tensor.NumElements());

  // memcpy over the payload ...
  // TODO(phawkins): handle string types.
  size_t total_bytes = host_tensor.TotalBytes();
  if (total_bytes > 0) {
    void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal);
    void* dst_ptr = literal->MutableInternalData();
    const void* src_ptr = DMAHelper::base(&host_tensor);
    memcpy(dst_ptr, src_ptr, total_bytes);
  }
@ -55,7 +55,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
  *host_tensor = Tensor(target_type, shape);
  size_t total_bytes = host_tensor->TotalBytes();
  if (total_bytes > 0) {
    const void* src_ptr = xla::LiteralUtil::InternalData(literal);
    const void* src_ptr = literal.InternalData();
    void* dst_ptr = DMAHelper::base(host_tensor);
    memcpy(dst_ptr, src_ptr, total_bytes);
  }
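Both conversions above follow the same pattern: size the destination for NumElements, then memcpy TotalBytes of flat payload in one shot. A standalone sketch of that round trip with plain vectors (the DMAHelper/Literal types are not reproduced):

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> host = {7, 42, -3};          // stands in for the Tensor
  std::vector<int> literal(host.size());        // "Reserve(NumElements)"

  size_t total_bytes = host.size() * sizeof(int);  // "TotalBytes()"
  if (total_bytes > 0) {
    // memcpy over the payload; valid only for trivially copyable element
    // types, which is why the original leaves string handling as a TODO.
    std::memcpy(literal.data(), host.data(), total_bytes);
  }
  std::cout << literal[1] << "\n";  // prints 42
}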
@ -27,7 +27,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
  {
    std::vector<int64> int64_values = {1, 2, 3};
    std::unique_ptr<xla::Literal> int64_values_literal =
        xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
        xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values));
    Tensor host_tensor;
    EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
              LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@ -48,7 +48,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
    Tensor host_tensor;
    std::vector<int32> int32_values = {10, 11};
    std::unique_ptr<xla::Literal> int32_values_literal =
        xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
        xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values));
    EXPECT_TRUE(
        LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
            .ok());
38  tensorflow/compiler/tf2xla/ops/BUILD  Normal file
@ -0,0 +1,38 @@
package(
    default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

cc_library(
    name = "functional_ops",
    srcs = ["functional_ops.cc"],
    deps = [
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

tf_gen_op_wrapper_py(
    name = "gen_functional_ops",
    out = "gen_functional_ops.py",
    deps = [
        ":functional_ops",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
45  tensorflow/compiler/tf2xla/ops/functional_ops.cc  Normal file
@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, being non-empty means True and being
      empty means False.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
)doc");

}  // namespace tensorflow
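The cond doc above defines how a non-boolean result is coerced to a boolean. A hedged standalone sketch of those rules (toy helpers, not TensorFlow's implementation):

#include <iostream>
#include <string>
#include <vector>

bool ScalarToBool(double v) { return v != 0.0; }                // numeric rule
bool ScalarToBool(const std::string& s) { return !s.empty(); }  // string rule

template <typename T>
bool TensorToBool(const std::vector<T>& t) {  // non-scalar: non-empty == true
  return !t.empty();
}

int main() {
  std::cout << ScalarToBool(0.0) << ScalarToBool(2.5)        // prints 01
            << ScalarToBool(std::string(""))                 // prints 0
            << TensorToBool(std::vector<int>{1, 2}) << "\n"; // prints 1
}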
42  tensorflow/compiler/tf2xla/test_util.cc  Normal file
@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/test_util.h"

#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {

Status InstantiateFunctionForTest(const string& name,
                                  const FunctionLibraryDefinition& library,
                                  InstantiationResultForTest* result) {
  const FunctionDef* fdef = library.Find(name);
  TF_RET_CHECK(fdef != nullptr);

  auto get_func_sig = [&library](const string& op, const OpDef** sig) {
    return library.LookUpOpDef(op, sig);
  };
  InstantiationResult inst;
  TF_RETURN_IF_ERROR(
      InstantiateFunction(*fdef, AttrSlice(), get_func_sig, &inst));
  result->arg_types = inst.arg_types;
  result->ret_types = inst.ret_types;
  for (NodeDef& n : inst.nodes) {
    *result->gdef.add_node() = std::move(n);
  }
  return Status::OK();
}

}  // namespace tensorflow
46  tensorflow/compiler/tf2xla/test_util.h  Normal file
@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Helper functions for tests.

#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_

#include <map>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Same as InstantiationResult, but has a GraphDef instead of just nodes.
struct InstantiationResultForTest {
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  GraphDef gdef;
};

// Instantiates a function, producing a GraphDef to compare against the
// expected graph.
Status InstantiateFunctionForTest(const string& name,
                                  const FunctionLibraryDefinition& library,
                                  InstantiationResultForTest* result);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice {
  std::unique_ptr<XlaCompilationAllocator> allocator_;
};

struct XlaVariable {
  // If this variable is visible externally, what was its argument number?
// Represents a resource, such as a Variable or TensorArray.
struct XlaResource {
  enum Kind {
    kInvalid,
    kVariable,
    kTensorArray,
  };

  Kind kind = kInvalid;

  // If this resource is visible externally, what was its argument number?
  int arg_num = -1;

  // A descriptive name for the variable, used in error messages.
  // A descriptive name for the resource, used in error messages.
  string name;

  // Current type and value of the variable. Uninitialized variables are
  // Current type and value of the resource. Uninitialized resources are
  // represented by a default (zero) handle and type DT_INVALID.
  // While the type of a variable is notionally fixed during execution, when
  // a variable is first initialized we do not yet know its type, so we keep
  // While the type of a resource is notionally fixed during execution, when
  // a resource is first initialized we do not yet know its type, so we keep
  // track of its type dynamically.
  DataType type = DT_INVALID;
  xla::ComputationDataHandle value;

  // Value of the variable at computation entry. Used to detect which
  // Value of the resource at computation entry. Used to detect which
  // variables have new values that need to be written back.
  xla::ComputationDataHandle initial_value;

  // We treat TensorArrays as a Variable with some extra metadata.
  // TensorArray-specific fields

  // 'tensor_array_size' stores the expected size of the TensorArray. We need
  // to store this since sometimes TensorArrays must be initialized lazily since
@ -91,10 +100,10 @@ struct XlaVariable {
  int64 tensor_array_size = -1;

  // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
  // to an XlaVariable containing the gradient TensorArrays. We store a pointer
  // to an XlaResource containing the gradient TensorArrays. We store a pointer
  // here since there should only be one gradient TensorArray per 'source'
  // string, irrespective of the number of calls to TensorArrayGrad.
  std::unordered_map<string, XlaVariable*> tensor_array_gradient;
  std::unordered_map<string, XlaResource*> tensor_array_gradient;
};

// A XlaExpression wraps an XLA computation. Each Tensor on an
@ -115,8 +124,8 @@ class XlaExpression {
  bool has_constant_value() const { return has_constant_value_; }
  const Tensor& constant_value() const { return constant_value_; }

  void set_variable(XlaVariable* variable) { variable_ = variable; }
  XlaVariable* variable() const { return variable_; }
  void set_resource(XlaResource* resource) { resource_ = resource; }
  XlaResource* resource() const { return resource_; }

 private:
  // The XLA handle of the expression's computation.
@ -128,7 +137,7 @@ class XlaExpression {
  bool has_constant_value_ = false;
  Tensor constant_value_;

  XlaVariable* variable_ = nullptr;  // Not owned.
  XlaResource* resource_ = nullptr;  // Not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};
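The new XlaResource struct above encodes "uninitialized" as a default (zero) value handle plus a default-invalid type. A minimal standalone mirror of that convention (illustrative only, not the real header):

#include <iostream>
#include <string>

struct Resource {
  enum Kind { kInvalid, kVariable, kTensorArray };

  Kind kind = kInvalid;
  int arg_num = -1;        // -1: not visible externally
  std::string name;
  long long handle = 0;    // 0 doubles as "no value yet"

  bool initialized() const { return handle > 0; }
};

int main() {
  Resource r;
  r.kind = Resource::kTensorArray;
  r.name = "ta";
  std::cout << r.initialized();          // prints 0: no handle assigned yet
  r.handle = 17;                         // first write initializes it
  std::cout << r.initialized() << "\n";  // prints 1
}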
@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>

#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -85,9 +86,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
    (*options_.populate_resource_manager)(device_->resource_manager());
  }

  flib_def_.reset(new FunctionLibraryDefinition(*options.flib_def));
  flib_runtime_.reset(NewFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), device_, options.graph_def_version,
      options.flib_def, OptimizerOptions(),
      flib_def_.get(), OptimizerOptions(),
      nullptr /* custom_kernel_creator */));
}
@ -249,35 +251,36 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
                      std::vector<xla::Shape>* input_shapes) {
  context_args->resize(args.size());

  // Argument numbers of arguments and variables that are to be passed to the
  // Argument numbers of arguments and resources that are to be passed to the
  // XLA computation as runtime parameters.
  std::vector<int> parameters, variables;
  std::vector<int> parameters, resources;
  parameters.reserve(args.size());
  variables.reserve(args.size());
  resources.reserve(args.size());

  for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
       ++i) {
    XlaContext::Argument& context_arg = (*context_args)[i];
    context_arg.kind = args[i].kind;
    context_arg.name = args[i].name;
    context_arg.value.constant_value = args[i].constant_value;
    context_arg.value.type = args[i].type;

    switch (args[i].kind) {
      case XlaCompiler::Argument::kVariable:
        variables.push_back(i);
        context_arg.is_variable = true;
        context_arg.value.is_constant = false;
      case XlaCompiler::Argument::kTensorArray:
        context_arg.is_resource = true;
        if (args[i].initialized) {
          resources.push_back(i);
          context_arg.value.is_constant = false;
        } else {
          context_arg.value.is_constant = true;
        }
        context_arg.tensor_array_size = args[i].tensor_array_size;
        break;
      case XlaCompiler::Argument::kParameter:
        parameters.push_back(i);
        context_arg.value.is_constant = false;
        break;
      case XlaCompiler::Argument::kUninitializedVariable:
        context_arg.is_variable = true;
        context_arg.value.is_constant = true;
        context_arg.tensor_array_size = args[i].tensor_array_size;
        break;
      case XlaCompiler::Argument::kConstant:
        context_arg.value.is_constant = true;
        break;
@ -288,7 +291,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,

  // Append parameters containing variable values after the other runtime
  // parameters.
  parameters.insert(parameters.end(), variables.begin(), variables.end());
  parameters.insert(parameters.end(), resources.begin(), resources.end());
  if (parameters.empty()) {
    return Status::OK();
  }
@ -329,22 +332,22 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// variable states, generated by the symbolic evaluation.
// If `has_side_effects` is true, the computation has side effects and should be
// built even if it has no outputs.
// If `return_updated_values_for_all_variables` is true, all variables will be
// included in `variable_updates`, regardless of whether their value changed.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*variable_updates` to a description of variables whose values are
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `variable_updates.size()` return values from the computation. Each entry in
// `variable_updates` is a (input_index, type) pair, where `input_index` is the
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
    const std::vector<XlaContext::HandleOrConstant>& retvals,
    const std::vector<std::unique_ptr<XlaVariable>>& variables,
    bool has_side_effects, bool return_updated_values_for_all_variables,
    const std::vector<std::unique_ptr<XlaResource>>& resources,
    bool has_side_effects, bool return_updated_values_for_all_resources,
    xla::ComputationBuilder* builder, xla::Computation* computation,
    int* num_nonconst_outputs,
    std::vector<XlaCompiler::VariableUpdate>* variable_updates) {
    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
  std::vector<xla::ComputationDataHandle> elems;
  elems.reserve(retvals.size());
  for (const XlaContext::HandleOrConstant& retval : retvals) {
@ -354,24 +357,24 @@ Status BuildComputation(
  }
  *num_nonconst_outputs = elems.size();

  // Add return values for variables whose values have changed.
  std::vector<const XlaVariable*> arg_vars;
  arg_vars.reserve(variables.size());
  for (const auto& var : variables) {
  // Add return values for resources whose values have changed.
  std::vector<const XlaResource*> arg_vars;
  arg_vars.reserve(resources.size());
  for (const auto& var : resources) {
    if (var->arg_num >= 0) {
      arg_vars.push_back(var.get());
    }
  }
  std::sort(arg_vars.begin(), arg_vars.end(),
            [](const XlaVariable* a, const XlaVariable* b) {
            [](const XlaResource* a, const XlaResource* b) {
              return a->arg_num < b->arg_num;
            });

  for (const XlaVariable* var : arg_vars) {
  for (const XlaResource* var : arg_vars) {
    bool modified = var->value.handle() != var->initial_value.handle();
    if (return_updated_values_for_all_variables || modified) {
      variable_updates->emplace_back();
      XlaCompiler::VariableUpdate& update = variable_updates->back();
    if (return_updated_values_for_all_resources || modified) {
      resource_updates->emplace_back();
      XlaCompiler::ResourceUpdate& update = resource_updates->back();
      update.input_index = var->arg_num;
      update.type = var->type;
      update.modified = modified;
@ -413,6 +416,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
  // Report the error here if initialization failed.
  TF_RETURN_IF_ERROR(initialization_status_);

  // Converts Tensorflow's graph control-flow constructs into functional
  // control-flow that can be compiled into XLA code.
  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph.get(), flib_def_.get()));

  xla::ComputationBuilder builder(client(), name);
  XlaContext* context =
      new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
@ -433,10 +440,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
  int num_nonconst_outputs;
  result->computation = std::make_shared<xla::Computation>();
  TF_RETURN_IF_ERROR(BuildComputation(
      context->retvals(), context->variables(), context->has_side_effects(),
      options.return_updated_values_for_all_variables, &builder,
      context->retvals(), context->resources(), context->has_side_effects(),
      options.return_updated_values_for_all_resources, &builder,
      result->computation.get(), &num_nonconst_outputs,
      &result->variable_updates));
      &result->resource_updates));

  result->requires_runtime_context = context->has_context_parameter();

@ -511,15 +518,15 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
    }
  }

  for (std::vector<VariableUpdate>::size_type i = 0;
       i < result->variable_updates.size(); ++i) {
  for (std::vector<ResourceUpdate>::size_type i = 0;
       i < result->resource_updates.size(); ++i) {
    if (num_computation_outputs > 1) {
      result->variable_updates[i].shape =
      result->resource_updates[i].shape =
          XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
              result->xla_output_shape, computation_output));
    } else {
      CHECK_EQ(0, computation_output);
      result->variable_updates[i].shape =
      result->resource_updates[i].shape =
          XLAShapeToTensorShape(result->xla_output_shape);
    }
    ++computation_output;
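BuildComputation above emits resource updates by filtering externally visible resources, sorting them by argument number, and keeping an entry when the value changed or when updates are forced for all resources. A standalone sketch of that selection logic with toy types (names are illustrative):

#include <algorithm>
#include <iostream>
#include <vector>

struct Res {
  int arg_num;        // < 0 means not externally visible
  int initial_value;  // stand-ins for the entry/exit handles
  int value;
};

struct Update {
  int input_index;
  bool modified;
};

std::vector<Update> CollectUpdates(std::vector<Res> resources,
                                   bool return_all) {
  // Keep only externally visible resources, ordered by argument number.
  std::vector<Res*> visible;
  for (Res& r : resources)
    if (r.arg_num >= 0) visible.push_back(&r);
  std::sort(visible.begin(), visible.end(),
            [](const Res* a, const Res* b) { return a->arg_num < b->arg_num; });

  // Emit an update when the value changed, or unconditionally if forced.
  std::vector<Update> updates;
  for (const Res* r : visible) {
    bool modified = r->value != r->initial_value;
    if (return_all || modified) updates.push_back({r->arg_num, modified});
  }
  return updates;
}

int main() {
  std::vector<Res> rs = {{2, 5, 5}, {0, 1, 9}, {-1, 0, 3}};
  auto u = CollectUpdates(rs, /*return_all=*/true);
  std::cout << u.size() << " " << u[0].input_index << "\n";  // prints "2 0"
}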
@ -85,14 +85,14 @@ class XlaCompiler {
    // Argument is a compile-time constant. No associated runtime parameter.
    kConstant,

    // Argument is a variable that has not been initialized yet. No associated
    // runtime parameter.
    kUninitializedVariable,

    // Argument is a variable that already has a value set. Expects a runtime
    // parameter containing the current value.
    // Argument is a variable resource. Has an associated runtime parameter
    // iff `initialized` is true.
    kVariable,

    // Argument is a TensorArray resource. Has an associated runtime parameter
    // iff `initialized` is true.
    kTensorArray,

    // Argument is a run-time parameter.
    kParameter,
  };
@ -114,8 +114,11 @@ class XlaCompiler {
    // The name of this argument, used for debugging.
    string name;

    // For a kVariable or kUninitializedVariable corresponding to a TensorArray,
    // what is the tensor array's declared size?
    // For a kVariable or kTensorArray, has this resource been initialized?
    bool initialized = false;

    // For a kTensorArray, what is the array's declared size? (Used for lazy
    // initialization.)
    int64 tensor_array_size = -1;

    bool operator==(const Argument& other) const;
@ -133,7 +136,7 @@ class XlaCompiler {
  };

  // Describes a variable write side effect of the computation.
  struct VariableUpdate {
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

@ -142,14 +145,14 @@ class XlaCompiler {
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_variables` is true.)
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;
  };

  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs and
    // variables, the parameters to the XLA computation may be a subset of the
    // resources, the parameters to the XLA computation may be a subset of the
    // original arguments, and are not necessarily in the same order.)
    std::vector<int> input_mapping;

@ -172,10 +175,10 @@ class XlaCompiler {
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // Variables whose values were updated by the computation, ordered
    // by return value position. Variable updates follow the non-constant
    // Resources whose values were updated by the computation, ordered
    // by return value position. Resource updates follow the non-constant
    // results in the outputs of XLA computation.
    std::vector<VariableUpdate> variable_updates;
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the tensorflow subgraph. May be null
    // if the output consists solely of compile-time constants.
@ -229,12 +232,12 @@ class XlaCompiler {
    // arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_variables' is true, then updated
    // values of all resource variable arguments will be included in the
    // 'variable_updates' of the computation, even if the variable was not
    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_variables = false;
    bool return_updated_values_for_all_resources = false;
  };

  // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
@ -294,6 +297,7 @@ class XlaCompiler {
  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;

  struct SignatureHash {
@ -163,9 +163,9 @@ TEST_F(XlaCompilerTest, Simple) {

  // Tests that the generated computation works.
  std::unique_ptr<xla::Literal> param0_literal =
      xla::LiteralUtil::CreateR1<int32>({7, 42});
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::Literal> param1_literal =
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
      xla::Literal::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
@ -179,7 +179,7 @@ TEST_F(XlaCompilerTest, Simple) {
      client_->Transfer(*actual).ConsumeValueOrDie();

  std::unique_ptr<xla::Literal> expected_literal =
      xla::LiteralUtil::CreateR1<int32>({4, 143});
      xla::Literal::CreateR1<int32>({4, 143});
  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}

@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {

  // Tests that the generated computation works.
  std::unique_ptr<xla::Literal> param0_literal =
      xla::LiteralUtil::CreateR1<int32>({7, 42});
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

@ -236,7 +236,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
      client_->Transfer(*actual).ConsumeValueOrDie();

  std::unique_ptr<xla::Literal> expected_literal =
      xla::LiteralUtil::CreateR1<int32>({-7, -42});
      xla::Literal::CreateR1<int32>({-7, -42});
  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}

@ -260,7 +260,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {

  // Tests that the generated computation works.
  std::unique_ptr<xla::Literal> param0_literal =
      xla::LiteralUtil::CreateR1<int32>({7, 42});
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

@ -270,12 +270,11 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
  std::unique_ptr<xla::Literal> actual_literal =
      client_->Transfer(*actual).ConsumeValueOrDie();

  std::unique_ptr<xla::Literal> expected0 =
      xla::LiteralUtil::CreateR0<int32>(7);
  std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
  std::unique_ptr<xla::Literal> expected1 =
      xla::LiteralUtil::CreateR1<int32>({-7, -42});
      xla::Literal::CreateR1<int32>({-7, -42});
  std::unique_ptr<xla::Literal> expected =
      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
      xla::Literal::MakeTuple({expected0.get(), expected1.get()});
  xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
}
}
@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() {

xla::ComputationBuilder* XlaContext::builder() { return builder_; }

Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
                                  string name, DataType type,
                                  const xla::ComputationDataHandle& handle,
                                  XlaVariable** variable) {
  variables_.emplace_back(new XlaVariable);
  *variable = variables_.back().get();
  XlaVariable& var = **variable;
  var.arg_num = arg_num;
  var.name = std::move(name);
  var.type = type;
  var.initial_value = var.value = handle;
                                  XlaResource** resource) {
  resources_.emplace_back(new XlaResource);
  *resource = resources_.back().get();
  XlaResource& r = **resource;
  r.kind = kind;
  r.arg_num = arg_num;
  r.name = std::move(name);
  r.type = type;
  r.initial_value = r.value = handle;
  return Status::OK();
}
@ -52,11 +52,13 @@ class XlaContext : public ResourceBase {
  };

  struct Argument {
    // Descriptive name for the variable, for use in error messages.
    XlaCompiler::Argument::Kind kind;

    // Descriptive name for the resource, for use in error messages.
    string name;

    // Is this a variable?
    bool is_variable = false;
    // Is this a resource?
    bool is_resource = false;

    HandleOrConstant value;

@ -106,15 +108,15 @@ class XlaContext : public ResourceBase {

  bool has_side_effects() const { return has_side_effects_; }

  // Creates a variable with variable `variable_id` and initial type `type` and
  // Creates a resource with resource `kind` and initial type `type` and
  // value `handle`. `name` is a descriptive name for use in error messages.
  // Fails if the variable already exists.
  Status CreateVariable(int arg_num, string name, DataType type,
                        const xla::ComputationDataHandle& handle,
                        XlaVariable** variable);
  // Fails if the resource already exists.
  Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
                        DataType type, const xla::ComputationDataHandle& handle,
                        XlaResource** resource);

  const std::vector<std::unique_ptr<XlaVariable>>& variables() {
    return variables_;
  const std::vector<std::unique_ptr<XlaResource>>& resources() {
    return resources_;
  }

  // Get an XLA lambda to compute Max. This is cached in the
@ -166,8 +168,8 @@ class XlaContext : public ResourceBase {
  // Does the computation have side effects, i.e., Send() calls?
  bool has_side_effects_ = false;

  // Holds ownership of variables. The variables are not ordered.
  std::vector<std::unique_ptr<XlaVariable>> variables_;
  // Holds ownership of resources. The resources are not ordered.
  std::vector<std::unique_ptr<XlaResource>> resources_;

  // Cache of prebuilt computations indexed by their type.
  using ComputationMap = std::map<DataType, xla::Computation>;
@ -30,28 +30,28 @@ xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
                                                DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return b->ConstantLiteral(xla::LiteralUtil::MinValue(type));
  return b->ConstantLiteral(xla::Literal::MinValue(type));
}

xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
                                                DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
  return b->ConstantLiteral(xla::Literal::MaxValue(type));
}

xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
                                            DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return b->ConstantLiteral(xla::LiteralUtil::Zero(type));
  return b->ConstantLiteral(xla::Literal::Zero(type));
}

xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
                                           DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return b->ConstantLiteral(xla::LiteralUtil::One(type));
  return b->ConstantLiteral(xla::Literal::One(type));
}

xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
@ -61,28 +61,28 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  switch (type) {
    case xla::U8:
      literal = *xla::LiteralUtil::CreateR0<uint8>(value);
      literal = *xla::Literal::CreateR0<uint8>(value);
      break;
    case xla::U32:
      literal = *xla::LiteralUtil::CreateR0<uint32>(value);
      literal = *xla::Literal::CreateR0<uint32>(value);
      break;
    case xla::U64:
      literal = *xla::LiteralUtil::CreateR0<uint64>(value);
      literal = *xla::Literal::CreateR0<uint64>(value);
      break;
    case xla::S8:
      literal = *xla::LiteralUtil::CreateR0<int8>(value);
      literal = *xla::Literal::CreateR0<int8>(value);
      break;
    case xla::S32:
      literal = *xla::LiteralUtil::CreateR0<int32>(value);
      literal = *xla::Literal::CreateR0<int32>(value);
      break;
    case xla::S64:
      literal = *xla::LiteralUtil::CreateR0<int64>(value);
      literal = *xla::Literal::CreateR0<int64>(value);
      break;
    case xla::F32:
      literal = *xla::LiteralUtil::CreateR0<float>(value);
      literal = *xla::Literal::CreateR0<float>(value);
      break;
    case xla::F64:
      literal = *xla::LiteralUtil::CreateR0<double>(value);
      literal = *xla::Literal::CreateR0<double>(value);
      break;
    case xla::PRED:
      LOG(FATAL) << "pred element type is not integral";
@ -91,7 +91,7 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case xla::F16:
      literal =
          *xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
          *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
      break;
    case xla::TUPLE:
      LOG(FATAL) << "tuple element type is not integral";
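IntegerLiteral above materializes one integer value as a literal of the requested primitive type via a per-type switch. A compact standalone sketch of the same dispatch shape (toy Prim/Literal types, not the xla:: API):

#include <cstdint>
#include <iostream>
#include <variant>

enum class Prim { U8, S32, F32 };

using Literal = std::variant<uint8_t, int32_t, float>;

// One value in, one typed literal out, chosen by the primitive type tag.
Literal IntegerLiteral(Prim type, long long value) {
  switch (type) {
    case Prim::U8:  return static_cast<uint8_t>(value);
    case Prim::S32: return static_cast<int32_t>(value);
    case Prim::F32: return static_cast<float>(value);
  }
  return int32_t{0};  // unreachable; keeps compilers quiet
}

int main() {
  Literal l = IntegerLiteral(Prim::F32, 3);
  std::cout << std::get<float>(l) << "\n";  // prints 3
}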
@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK(expression->handle().handle() != 0 ||
|
||||
expression->variable() != nullptr);
|
||||
expression->resource() != nullptr);
|
||||
VLOG(1) << "Fetched T" << expression->handle().handle();
|
||||
return expression;
|
||||
}
|
||||
@ -144,9 +144,9 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
|
||||
return errors::InvalidArgument("value is not a scalar");
|
||||
}
|
||||
if (literal.shape().element_type() == xla::S32) {
|
||||
*out = xla::LiteralUtil::Get<int32>(literal, {});
|
||||
*out = literal.Get<int32>({});
|
||||
} else if (literal.shape().element_type() == xla::S64) {
|
||||
*out = xla::LiteralUtil::Get<int64>(literal, {});
|
||||
*out = literal.Get<int64>({});
|
||||
} else {
|
||||
return errors::InvalidArgument("value must be either int32 or int64");
|
||||
}
|
||||
@ -168,11 +168,11 @@ static Status LiteralToInt64Vector(const xla::Literal& literal,
|
||||
int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
|
||||
if (literal.shape().element_type() == xla::S32) {
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
out->push_back(xla::LiteralUtil::Get<int32>(literal, {i}));
|
||||
out->push_back(literal.Get<int32>({i}));
|
||||
}
|
||||
} else if (literal.shape().element_type() == xla::S64) {
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
out->push_back(xla::LiteralUtil::Get<int64>(literal, {i}));
|
||||
out->push_back(literal.Get<int64>({i}));
|
||||
}
|
||||
} else {
|
||||
return errors::InvalidArgument("value must be either int32 or int64");
|
||||
@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput(
|
||||
int index, xla::ComputationDataHandle* value) {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
XlaVariable* variable = expression->variable();
|
||||
XlaResource* variable = expression->resource();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind == XlaResource::kVariable);
|
||||
if (variable->value.handle() == 0) {
|
||||
return errors::InvalidArgument("Read of uninitialized variable ",
|
||||
variable->name);
|
||||
@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string XlaOpKernelContext::VariableDebugString(int index) {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
XlaVariable* variable = expression->variable();
|
||||
if (!variable) {
|
||||
return "<invalid variable ID>";
|
||||
}
|
||||
return variable->name;
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
||||
TensorShape* shape) const {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
XlaVariable* variable = expression->variable();
|
||||
XlaResource* variable = expression->resource();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind == XlaResource::kVariable);
|
||||
if (variable->value.handle() == 0) {
|
||||
return errors::InvalidArgument("Read of uninitialized variable ",
|
||||
variable->name);
|
||||
@ -337,33 +329,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  expression->set_constant_value(constant);
}

void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
  Tensor* output = nullptr;
  // The shape of the output tensor is the shape of the variable resource
  // (i.e., a scalar), not the shape of the variable's value.
  // The shape of the output tensor is the shape of the resource itself
  // (i.e., a scalar), not the shape of the resource's value.
  OP_REQUIRES_OK(context_,
                 context_->allocate_output(index, TensorShape(), &output));
  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_variable(variable);
  expression->set_resource(resource);
}

Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const XlaExpression* expression =
      CastExpressionFromTensor(context_->input(index));
  TF_RET_CHECK(expression->variable() != nullptr);
  *variable = expression->variable();
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}

Status XlaOpKernelContext::AssignVariable(
    int index, DataType type, const xla::ComputationDataHandle& handle) {
    int input_index, DataType type, const xla::ComputationDataHandle& handle) {
  TF_RET_CHECK(handle.handle() != 0);
  SetOpHasSideEffects();

  const XlaExpression* expression =
      CastExpressionFromTensor(context_->input(index));
  XlaVariable* variable = expression->variable();
      CastExpressionFromTensor(context_->input(input_index));
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind == XlaResource::kVariable);
  if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
        (variable->type == type))) {
    return errors::InvalidArgument(
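These hunks fold XlaVariable into the more general XlaResource and rename the kernel-context entry points accordingly. A hedged sketch of the renamed API from a hypothetical op kernel, assuming only the methods shown above:

// Pass the resource of input 0 straight through to output 0.
void PropagateResource(XlaOpKernelContext* ctx) {
  XlaResource* resource = nullptr;
  OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));  // was GetVariableInput
  ctx->SetResourceOutput(0, resource);                       // was SetVariableOutput
}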
@ -148,6 +148,12 @@ class XlaOpKernelContext {

  // Variables

  // Sets '*resource' to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output 'index' to be a reference to resource 'resource'.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,

@ -158,20 +164,10 @@ class XlaOpKernelContext {
  Status ReadVariableInput(int index, xla::ComputationDataHandle* value);

  // Assigns the value `handle` to the variable referenced by input
  // `variable_index`. Marks the operator as having side effects.
  Status AssignVariable(int variable_index, DataType type,
  // `input_index`. Marks the operator as having side effects.
  Status AssignVariable(int input_index, DataType type,
                        const xla::ComputationDataHandle& handle);

  // Sets '*variable' to the variable associated with input `index`.
  Status GetVariableInput(int index, XlaVariable** variable);

  // Sets output 'index' to be a reference to variable 'variable'. Used
  // to propagate resource variables through the compilation.
  void SetVariableOutput(int index, XlaVariable* variable);

  // Returns a human-readable debug string describing 'variable_index'.
  string VariableDebugString(int variable_index);

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(Status s);
  void CtxFailureWithWarning(Status s);
@ -34,11 +34,18 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
const char* const DEVICE_XLA_CPU = "XLA_CPU";
const char* const DEVICE_XLA_GPU = "XLA_GPU";

// Is platform 'id' supported by XLA?
static bool IsPlatformSupported(perftools::gputools::Platform::Id id) {
  auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id);
  if (!platform.ok()) return false;
  return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok();
static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
  NodeDef node_def;
  node_def.set_name("_XlaLaunch-op");
  node_def.set_op("_XlaLaunch");
  string kernel_class_name;
  TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
                                   &kernel_class_name));
  VLOG(1) << "LaunchOpHasKernelForDevice"
          << " kernel_class_name: " << kernel_class_name;
  return Status::OK();
}

XlaOpRegistry::XlaOpRegistry() = default;

@ -75,7 +82,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// GetCompilationDevice is called.
static void* registration_init = [&registry]() {
  mutex_lock lock(registry.mutex_);
  if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) {
  if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
    DeviceRegistration& registration =
        registry.compilation_devices_[DEVICE_CPU];
    registration.compilation_device_name = DEVICE_CPU_XLA_JIT;

@ -83,7 +90,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
    registration.enable_jit_by_default = false;
    registration.compile_resource_ops = false;
  }
  if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) {
  if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
    DeviceRegistration& registration =
        registry.compilation_devices_[DEVICE_GPU];
    registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
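The registration probe changes from asking StreamExecutor whether the platform exists to asking the kernel registry whether _XlaLaunch has a kernel registered for the device. A sketch of the same probe as a boolean helper, assuming only the FindKernelDef overload already used above:

static bool HasXlaLaunchKernel(const DeviceType& device_type) {
  NodeDef node_def;
  node_def.set_name("_XlaLaunch-op");
  node_def.set_op("_XlaLaunch");
  string kernel_class_name;
  // Succeeds iff some kernel for `device_type` is registered for _XlaLaunch.
  return FindKernelDef(device_type, node_def, /*def=*/nullptr,
                       &kernel_class_name)
      .ok();
}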
@ -46,21 +46,18 @@ xla_proto_library(
    ],
)

# This is a headers target that extra XLA devices can use to prevent
# circular dependencies. Devices that are compiled as separate shared
# objects can also use it to prevent linking of library code.
cc_header_only_library(
    name = "xla_headers_lib",
    visibility = ["//visibility:public"],
cc_library(
    name = "execution_options_util",
    srcs = [
        "execution_options_util.cc",
    ],
    hdrs = [
        "execution_options_util.h",
    ],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_evaluator",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:stream_executor_headers_lib",
        ":xla_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
    ],
)

@ -602,3 +599,18 @@ filegroup(
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
cc_header_only_library(
    name = "xla_headers_lib",
    visibility = ["//visibility:public"],
    deps = [
        ":xla_data_proto",
        ":xla_proto",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:stream_executor_headers_lib",
    ],
)
@ -114,7 +114,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:compile_only_service",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "@llvm//:support",
    ],
@ -971,6 +971,11 @@ ComputationDataHandle ComputationBuilder::Sign(
  return UnaryOp(UNOP_SIGN, operand);
}

ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_COS, operand);
}

ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_TANH, operand);

@ -1411,6 +1416,52 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }
  BatchNormTrainingRequest request;
  *request.mutable_operand() = operand;
  *request.mutable_scale() = scale;
  *request.mutable_offset() = offset;
  request.set_epsilon(epsilon);
  request.set_feature_index(feature_index);

  OpRequest op_request;
  *op_request.mutable_batch_norm_training_request() = request;
  *op_request.mutable_computation() = computation_.handle();
  AddOpMetadata(&op_request);

  OpResponse response;

  VLOG(2) << "making BatchNormTraining request";

  Status s = client_->stub()->Op(&op_request, &response);
  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  // TODO(b/62843645): Implement BatchNormInference.
  NoteError(Unimplemented("BatchNormInference is not implemented yet."));
  return ComputationDataHandle();
}

ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& batch_mean,
    const ComputationDataHandle& batch_var,
    const ComputationDataHandle& grad_output, float epsilon,
    int64 feature_index) {
  // TODO(b/62843645): Implement BatchNormGrad.
  NoteError(Unimplemented("BatchNormGrad is not implemented yet."));
  return ComputationDataHandle();
}
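BatchNormTraining above only marshals the request proto; the numerical semantics are the standard batch-normalization transform, which this diff does not restate. For reference (an assumption from the standard definition, not from this commit), per feature along feature_index:

y = \mathit{scale} \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \mathit{offset}

where \mu and \sigma^2 are the batch mean and variance of that feature, and \epsilon is the epsilon field carried in the request.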
ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {

@ -1487,6 +1538,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  ReducePrecisionRequest request;
  *request.mutable_operand() = operand;
  request.set_exponent_bits(exponent_bits);
  request.set_mantissa_bits(mantissa_bits);
  OpRequest op_request;
  *op_request.mutable_computation() = computation_.handle();
  *op_request.mutable_reduce_precision_request() = request;
  AddOpMetadata(&op_request);
  OpResponse response;

  VLOG(2) << "making reduce-precision request";
  Status s = client_->stub()->Op(&op_request, &response);
  return ParseOpResponse(s, &response);
}
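ReducePrecision keeps values in their original floating-point type while rounding them as if they had been stored in a narrower format. A hypothetical usage sketch, assuming a ComputationBuilder `b` has already been constructed; 5 exponent and 10 mantissa bits emulate IEEE half precision:

// Round an F32 value through an F16-like format, result stays F32.
auto x = b.ConstantR0<float>(3.14159f);
auto rounded = b.ReducePrecision(x, /*exponent_bits=*/5, /*mantissa_bits=*/10);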
void ComputationBuilder::Send(const ComputationDataHandle& operand,
                              const ChannelHandle& handle) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {

@ -510,6 +510,9 @@ class ComputationBuilder {
  // Enqueues a sign instruction onto the computation.
  ComputationDataHandle Sign(const ComputationDataHandle& operand);

  // Enqueues a cosine instruction onto the computation.
  ComputationDataHandle Cos(const ComputationDataHandle& operand);

  // Enqueues a tanh instruction onto the computation.
  ComputationDataHandle Tanh(const ComputationDataHandle& operand);

@ -597,6 +600,11 @@ class ComputationBuilder {
                             const Computation& body,
                             const ComputationDataHandle& init);

  // Enqueues a ReducePrecision node onto the computation.
  ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
                                        const int exponent_bits,
                                        const int mantissa_bits);

  // Enqueues a Send node onto the computation, to send the given operand to
  // a Recv instruction that shares the same channel handle.
  void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
@ -820,87 +828,80 @@ class ComputationBuilder {

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
  return ConstantOp(
      [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); });
  return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(
    tensorflow::gtl::ArraySlice<NativeT> values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR1(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR1(values); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
                                                     NativeT value) {
  return ConstantOp([length, value](Literal* literal) {
    LiteralUtil::PopulateWithValue(value, {length}, literal);
    literal->PopulateWithValue(value, {length});
  });
}

inline ComputationDataHandle ComputationBuilder::ConstantR1(
    const tensorflow::core::Bitmap& values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR1(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR1(values); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR2(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR2(values); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
    const Array2D<NativeT>& values, const Layout& layout) {
  return ConstantOp([&values, &layout](Literal* literal) {
    LiteralUtil::PopulateR2FromArray2DWithLayout(values, layout, literal);
    literal->PopulateR2FromArray2DWithLayout(values, layout);
  });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
    const Array2D<NativeT>& values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR2FromArray2D(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
    const Array3D<NativeT>& values, const Layout& layout) {
  return ConstantOp([&values, &layout](Literal* literal) {
    LiteralUtil::PopulateR3FromArray3DWithLayout(values, layout, literal);
    literal->PopulateR3FromArray3DWithLayout(values, layout);
  });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
    const Array3D<NativeT>& values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR3FromArray3D(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
    const Array4D<NativeT>& values, const Layout& layout) {
  return ConstantOp([&values, &layout](Literal* literal) {
    LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal);
    literal->PopulateR4FromArray4DWithLayout(values, layout);
  });
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
    const Array4D<NativeT>& values) {
  return ConstantOp([&values](Literal* literal) {
    LiteralUtil::PopulateR4FromArray4D(values, literal);
  });
  return ConstantOp(
      [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
}

}  // namespace xla
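The rewrite above moves every PopulateR* entry point from static LiteralUtil functions to Literal members; the ConstantRN wrappers themselves are unchanged. A hedged sketch of the unchanged caller-facing surface, assuming a ComputationBuilder `b`:

// Callers still go through the ConstantRN helpers; only the internals moved.
auto c0 = b.ConstantR0<float>(2.0f);
auto c2 = b.ConstantR2<int32>({{1, 2}, {3, 4}});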
@ -32,6 +32,7 @@ cc_library(
    srcs = ["testing.cc"],
    hdrs = ["testing.h"],
    deps = [
        "//tensorflow/compiler/xla:execution_options_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

@ -34,11 +35,11 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
      client,
      tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
  // TODO(b/26811613): Replace this when RNG is supported on all backends.
  b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())),
  b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())),
              AsInt64Slice(shape.dimensions()));
  Computation computation = b.Build().ConsumeValueOrDie();

  ExecutionOptions execution_options;
  auto execution_options = CreateDefaultExecutionOptions();
  *execution_options.mutable_shape_with_output_layout() = shape;
  return client->Execute(computation, /*arguments=*/{}, &execution_options)
      .ConsumeValueOrDie();
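MakeFakeDataOrDie now seeds its ExecutionOptions from flags instead of from a default-constructed proto. A minimal sketch of what that buys, assuming only CreateDefaultExecutionOptions and the DebugOptions fields set elsewhere in this commit:

auto opts = xla::CreateDefaultExecutionOptions();
// opts.debug_options() now carries the flag-derived defaults, e.g.:
bool fast_math = opts.debug_options().xla_enable_fast_math();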
@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
  return execution_profile_;
}

ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
    DeviceAssignment* device_assignment) {
  device_assignment_ = device_assignment;
  return *this;
}

DeviceAssignment* ExecutableRunOptions::device_assignment() const {
  return device_assignment_;
}

}  // namespace xla
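The new setter follows the class's existing fluent pattern and returns *this, so it chains with the other setters. A hypothetical sketch, assuming a DeviceAssignment instance is obtainable from the service side (its construction is not shown in this commit):

xla::DeviceAssignment assignment;  // assumed obtained from the service elsewhere
xla::ExecutableRunOptions run_options;
run_options.set_device_assignment(&assignment)
    .set_execution_profile(nullptr);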
@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {

class DeviceMemoryAllocator;
class DeviceAssignment;
class ExecutionProfile;

// Class containing options for running a LocalExecutable.

@ -79,9 +80,14 @@ class ExecutableRunOptions {
  ExecutionProfile* execution_profile() const;
  ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);

  ExecutableRunOptions& set_device_assignment(
      DeviceAssignment* device_assignment);
  DeviceAssignment* device_assignment() const;

 private:
  DeviceMemoryAllocator* allocator_ = nullptr;
  int device_ordinal_ = -1;
  DeviceAssignment* device_assignment_ = nullptr;
  perftools::gputools::Stream* stream_ = nullptr;
  tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
  const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
@ -1,6 +1,4 @@
<!--
@license
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

@ -13,19 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"

<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-imports/d3.html">
namespace xla {

<!--
tf-color-scale is a plumbing component that takes in an array of runs, and produces
an upward-bindable outColorScale, which is a color scale mapping from those runs to
a set of colors.
ExecutionOptions CreateDefaultExecutionOptions() {
  ExecutionOptions execution_options;
  *(execution_options.mutable_debug_options()) =
      legacy_flags::GetDebugOptionsFromFlags();
  return execution_options;
}

@element tf-color-scale
-->
<dom-module id="tf-color-scale">
  <script src="palettes.js"></script>
  <script src="colorScale.js"></script>
</dom-module>
}  // namespace xla
29
tensorflow/compiler/xla/execution_options_util.h
Normal file
@ -0,0 +1,29 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_

#include "tensorflow/compiler/xla/xla.pb.h"

namespace xla {

// Creates a default ExecutionOptions proto; this proto has its debug options
// populated to the default values taken from flags.
ExecutionOptions CreateDefaultExecutionOptions();

}  // namespace xla

#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
@ -73,26 +73,12 @@ cc_library(
    deps =
        [
            ":parse_flags_from_env",
            "//tensorflow/compiler/xla:types",
            "//tensorflow/compiler/xla:xla_proto",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)

cc_library(
    name = "cpu_compiler_flags",
    srcs = ["cpu_compiler_flags.cc"],
    hdrs = ["cpu_compiler_flags.h"],
    deps =
        [
            ":parse_flags_from_env",
            "//tensorflow/compiler/xla:types",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)

cc_library(
    name = "cpu_runtime_flags",
    srcs = ["cpu_runtime_flags.cc"],

@ -128,30 +114,6 @@ cc_library(
    ],
)

cc_library(
    name = "gpu_compiler_flags",
    srcs = ["gpu_compiler_flags.cc"],
    hdrs = ["gpu_compiler_flags.h"],
    deps = [
        ":parse_flags_from_env",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "gpu_backend_lib_flags",
    srcs = ["gpu_backend_lib_flags.cc"],
    hdrs = ["gpu_backend_lib_flags.h"],
    deps = [
        ":parse_flags_from_env",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "stream_assignment_flags",
    srcs = ["stream_assignment_flags.cc"],

@ -175,28 +137,6 @@ cc_library(
    ],
)

cc_library(
    name = "alias_analysis_flags",
    srcs = ["alias_analysis_flags.cc"],
    hdrs = ["alias_analysis_flags.h"],
    deps = [
        ":parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "llvm_util_flags",
    srcs = ["llvm_util_flags.cc"],
    hdrs = ["llvm_util_flags.h"],
    deps = [
        ":parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "service_flags",
    srcs = ["service_flags.cc"],
@ -1,62 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

// Legacy flags for XLA's alias_analysis module.

#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static AliasAnalysisFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new AliasAnalysisFlags;
  flags->xla_emit_alias_scope = true;
  flag_list = new std::vector<tensorflow::Flag>({
      tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope,
                       "Use buffer analysis to refine alias-analysis."),
  });
  ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with XLA's alias_analysis
// module.
void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the AliasAnalysisFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
AliasAnalysisFlags* GetAliasAnalysisFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace xla
@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_

// Legacy flags for XLA's alias_analysis module.

#include <vector>

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Append to *flag_list flag definitions associated with XLA's alias_analysis
// module.
void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* flag_list);

// The values of flags associated with XLA's alias_analysis module.
typedef struct {
  bool xla_emit_alias_scope;  // Use buffer analysis to refine alias-analysis.
} AliasAnalysisFlags;

// Return a pointer to the AliasAnalysisFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
AliasAnalysisFlags* GetAliasAnalysisFlags();

}  // namespace legacy_flags
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
@ -1,68 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

// Legacy flags for XLA's cpu_compiler module.

#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static CpuCompilerFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new CpuCompilerFlags;
  flags->xla_cpu_embed_ir = false;
  flags->xla_cpu_dump_debug_json_to = "";
  flag_list = new std::vector<tensorflow::Flag>({
      tensorflow::Flag(
          "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir,
          "Embed the LLVM IR module string in the resultant CpuExecutable."),
      tensorflow::Flag("xla_cpu_dump_debug_json_to",
                       &flags->xla_cpu_dump_debug_json_to,
                       "Dump debug JSON to this directory."),
  });
  ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with XLA's cpu_compiler
// module.
void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the CpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
CpuCompilerFlags* GetCpuCompilerFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace xla
@ -1,49 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_

// Legacy flags for XLA's cpu_compiler module.

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Append to *flag_list flag definitions associated with XLA's cpu_compiler
// module.
void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);

// The values of flags associated with XLA's cpu_compiler module.
typedef struct {
  bool xla_cpu_embed_ir;  // Embed the LLVM IR module string in the resultant
                          // CpuExecutable.
  string xla_cpu_dump_debug_json_to;  // Dump debug JSON to this directory.
} CpuCompilerFlags;

// Return a pointer to the CpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
CpuCompilerFlags* GetCpuCompilerFlags();

}  // namespace legacy_flags
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
@ -28,6 +28,12 @@ struct DebugOptionsFlags {
  string xla_disable_hlo_passes;
  bool xla_enable_fast_math;
  int32 xla_backend_optimization_level;
  bool xla_embed_ir_in_executable;
  string xla_dump_debug_json_to;

  string xla_gpu_cuda_data_dir;
  bool xla_gpu_ftz;

  string xla_backend_extra_options;
};

@ -44,7 +50,11 @@ void AllocateFlags() {
  flag_values->xla_generate_hlo_graph = "";
  flag_values->xla_disable_hlo_passes = "";
  flag_values->xla_enable_fast_math = true;
  flag_values->xla_backend_optimization_level = 2;
  flag_values->xla_backend_optimization_level = 3;
  flag_values->xla_embed_ir_in_executable = false;
  flag_values->xla_dump_debug_json_to = "";
  flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
  flag_values->xla_gpu_ftz = false;
  flag_values->xla_backend_extra_options = "";

  flag_objects = new std::vector<tensorflow::Flag>(

@ -52,7 +62,6 @@ void AllocateFlags() {
          "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
          "HLO modules matching this regex will be dumped to a .dot file "
          "throughout various stages in compilation."),

      tensorflow::Flag(
          "xla_enable_fast_math", &flag_values->xla_enable_fast_math,
          "Enable unsafe fast-math optimizations in the compiler; "

@ -61,18 +70,31 @@ void AllocateFlags() {
          "xla_backend_optimization_level",
          &flag_values->xla_backend_optimization_level,
          "Numerical optimization level for the XLA compiler backend."),

      tensorflow::Flag(
          "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
          "Comma-separated list of hlo passes to be disabled. These names "
          "must exactly match the passes' names; no whitespace around "
          "commas."),
      tensorflow::Flag("xla_embed_ir_in_executable",
                       &flag_values->xla_embed_ir_in_executable,
                       "Embed the compiler IR as a string in the executable."),
      tensorflow::Flag("xla_gpu_cuda_data_dir",
                       &flag_values->xla_gpu_cuda_data_dir,
                       "If non-empty, specifies a local directory containing "
                       "ptxas and nvvm libdevice files; otherwise we use "
                       "those from runfile directories."),
      tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz,
                       "If true, flush-to-zero semantics are enabled in the "
                       "code generated for GPUs."),
      tensorflow::Flag(
          "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to,
          "Dump compilation artifacts as JSON into this directory."),
      tensorflow::Flag("xla_backend_extra_options",
                       &flag_values->xla_backend_extra_options,
                       "Extra options to pass to a backend; "
                       "comma-separated list of 'key=val' strings (=val "
                       "may be omitted); no whitespace around commas."),
                       "may be omitted); no whitespace around commas.")});

      tensorflow::Flag(
          "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
          "Comma-separated list of HLO passes to be disabled. These names "
          "must exactly match the passes' names; "
          "no whitespace around commas.")});
  ParseFlagsFromEnv(*flag_objects);
}

@ -99,6 +121,11 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
  options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math);
  options.set_xla_backend_optimization_level(
      flag_values->xla_backend_optimization_level);
  options.set_xla_embed_ir_in_executable(
      flag_values->xla_embed_ir_in_executable);
  options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
  options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir);
  options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz);

  std::vector<string> extra_options_parts =
      tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ',');
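Net effect of these hunks: the per-module legacy flags collapse into DebugOptionsFlags, and GetDebugOptionsFromFlags() copies them into the DebugOptions proto that rides inside ExecutionOptions. A minimal sketch of the consumption path, assuming only the functions this commit adds:

xla::DebugOptions debug_options = xla::legacy_flags::GetDebugOptionsFromFlags();
xla::ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;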
@ -1,88 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

// Legacy flags for XLA's gpu_backend_lib module.

#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static GpuBackendLibFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new GpuBackendLibFlags;
  flags->dump_temp_products_to = "";
  flags->ftz = false;
  flags->fma = true;
  flags->verbose_ptx_asm = false;
  flags->kernel = "";
  flags->llvm_dump_passes = false;
  flags->llvm_cl_opts = "";
  flags->dump_ir_before_passes = false;
  flags->opt_level = 3;
  flag_list = new std::vector<tensorflow::Flag>({
      tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to,
                       "dump temporary compilation products to this directory. "
                       "If empty, no dump is produced"),
      tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"),
      tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"),
      tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm,
                       "emit PTX assembly with extra comments"),
      tensorflow::Flag("kernel", &flags->kernel,
                       "only emit the IR and PTX for this kernel"),
      tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes,
                       "dump the passes LLVM runs to stderr"),
      tensorflow::Flag(
          "llvm_cl_opts", &flags->llvm_cl_opts,
          "comma-separated list of command line options to pass to "
          "LLVM. For example, --llvm_cl_opts=--print-before=loop-unroll"),
      tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes,
                       "dump the IR before each optimization pass in "
                       "sequentially-named files."),
      tensorflow::Flag("opt_level", &flags->opt_level,
                       "optimization level (default to 3)"),
  });
  ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with XLA's gpu_backend_lib
// module.
void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the GpuBackendLibFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuBackendLibFlags* GetGpuBackendLibFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace xla
@ -1,55 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_

// Legacy flags for XLA's gpu_backend_lib module.

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib
// module.
void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* flag_list);

// The values of flags associated with XLA's gpu_backend_lib module.
typedef struct {
  string dump_temp_products_to;  // temporary compilation products dir
  bool ftz;                      // flush to zero semantics
  bool fma;                      // use FMA synthesis
  bool verbose_ptx_asm;          // emit PTX assembly with extra comments
  string kernel;                 // only emit the IR and PTX for this kernel
  bool llvm_dump_passes;         // dump the passes LLVM runs to stderr
  string llvm_cl_opts;           // comma-separated list of LLVM options
  bool dump_ir_before_passes;    // dump IR before each pass
  int32 opt_level;               // optimization level
} GpuBackendLibFlags;

// Return a pointer to the GpuBackendLibFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuBackendLibFlags* GetGpuBackendLibFlags();

}  // namespace legacy_flags
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
@ -1,76 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

// Legacy flags for XLA's gpu_compiler module.

#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static GpuCompilerFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new GpuCompilerFlags;
  flags->xla_gpu_embed_ir = false;
  flags->xla_cuda_data_dir = "./cuda_sdk_lib";
  flags->xla_gpu_dump_debug_json_to = "";
  flag_list = new std::vector<tensorflow::Flag>({
      tensorflow::Flag(
          "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
          "Embed the LLVM IR module string in the resultant GpuExecutable."),
      tensorflow::Flag(
          "xla_cuda_data_dir", &flags->xla_cuda_data_dir,
          "If non-empty, specifies a local directory containing ptxas and "
          "nvvm libdevice files. Otherwise, by default, we use those from "
          "runfile directories."),
      tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path,
                       "The path to ptxas. Required to log stats of the ptx."),
      tensorflow::Flag("xla_gpu_dump_debug_json_to",
                       &flags->xla_gpu_dump_debug_json_to,
                       "Dump debug JSON to this directory."),
  });
  ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with XLA's gpu_compiler
// module.
void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the GpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuCompilerFlags* GetGpuCompilerFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace xla
@ -1,55 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_

// Legacy flags for XLA's gpu_compiler module.

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Append to *flag_list flag definitions associated with XLA's gpu_compiler
// module.
void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);

// The values of flags associated with XLA's gpu_compiler module.
typedef struct {
  bool xla_gpu_embed_ir;     // Embed the LLVM IR module string in the resultant
                             // GpuExecutable.
  string xla_cuda_data_dir;  // If non-empty, specifies a local directory
                             // containing ptxas and nvvm libdevice files.
                             // Otherwise, by default, we use those from runfile
                             // directories.
  string xla_ptxas_path;     // The path to ptxas. Required to log stats of
                             // the ptx.
  string xla_gpu_dump_debug_json_to;  // Dump debug JSON to this directory.
} GpuCompilerFlags;

// Return a pointer to the GpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuCompilerFlags* GetGpuCompilerFlags();

}  // namespace legacy_flags
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
@ -1,63 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

// Legacy flags for XLA's llvm_util module.

#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static LlvmUtilFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new LlvmUtilFlags;
  flags->xla_emit_tbaa = true;
  flag_list = new std::vector<tensorflow::Flag>({
      tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa,
                       "Perform type-based alias analysis optimizations for "
                       "LLVM-based backends."),
  });
  ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with XLA's llvm_util
// module.
void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the LlvmUtilFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
LlvmUtilFlags* GetLlvmUtilFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace xla
@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License"); identical
   to the full header reproduced for execution_options_util.h above. */

#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_

// Legacy flags for XLA's llvm_util module.

#include <vector>

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace legacy_flags {

// Append to *flag_list flag definitions associated with XLA's llvm_util module.
void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* flag_list);

// The values of flags associated with XLA's llvm_util module.
typedef struct {
  bool xla_emit_tbaa;  // Perform type-based alias analysis optimizations for
                       // LLVM-based backends.
} LlvmUtilFlags;

// Return a pointer to the LlvmUtilFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
LlvmUtilFlags* GetLlvmUtilFlags();

}  // namespace legacy_flags
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
@ -321,6 +321,7 @@ Status Literal::Copy(const Literal& src_literal,
}

std::unique_ptr<Literal> Literal::Relayout(const Layout& layout) const {
  CHECK(ShapeUtil::IsArray(shape()));
  std::unique_ptr<Literal> result = CloneToUnique();
  *result->mutable_shape()->mutable_layout() = layout;

@ -754,10 +755,30 @@ void Literal::EachCellAsString(
}

namespace {
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
  auto result_literal = MakeUnique<Literal>();
  Shape* result_shape = result_literal->mutable_shape();
  *result_shape = src_literal.shape();
  result_shape->set_element_type(
      primitive_util::NativeToPrimitiveType<NativeDestT>());
  result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
  tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
      src_literal.GetArraySlice<NativeSrcT>();
  tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
      result_literal->GetMutableArraySlice<NativeDestT>();
  int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());

  for (int64 i = 0; i < num_elements; ++i) {
    dest_data[i] = static_cast<NativeDestT>(src_data[i]);
  }
  return result_literal;
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  return LiteralUtil::Convert<
  return ConvertBetweenNativeTypes<
      typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
      typename primitive_util::PrimitiveTypeToNative<
          primitive_dest_type>::type>(src_literal);
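The dispatch above selects a NativeSrcT/NativeDestT pair from runtime PrimitiveType values and delegates to an element-wise static_cast loop. A self-contained analogue of that loop, with hypothetical names, purely for illustration:

// Element-wise conversion: allocate the destination buffer, then
// static_cast each element across, exactly as the helper above does.
#include <cstddef>
#include <vector>

template <typename SrcT, typename DestT>
std::vector<DestT> CastVector(const std::vector<SrcT>& src) {
  std::vector<DestT> dest(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    dest[i] = static_cast<DestT>(src[i]);
  }
  return dest;
}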
@ -782,19 +803,20 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
#undef CONVERT_IF_TYPES_MATCH
    // Other types are not yet supported.
    default:
      return tensorflow::errors::InvalidArgument(
          "Unimplemented: ConvertIfDestTypeMatches for type " +
          PrimitiveType_Name(src_literal.shape().element_type()));
      return InvalidArgument(
          "Unimplemented: Convert from type %s to type %s",
          PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
          PrimitiveType_Name(primitive_dest_type).c_str());
  }
}
}
}  // namespace

StatusOr<std::unique_ptr<Literal>> LiteralUtil::ConvertIfSrcTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (src_literal.shape().element_type()) {
StatusOr<std::unique_ptr<Literal>> Literal::Convert(
    PrimitiveType primitive_dest_type) const {
  switch (shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
  case (type):                             \
    return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
    return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
@ -807,9 +829,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralUtil::ConvertIfSrcTypeMatches(
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      return tensorflow::errors::InvalidArgument(
          "Unimplemented: ConvertIfSrcTypeMatches for type " +
          PrimitiveType_Name(src_literal.shape().element_type()));
      return InvalidArgument("Unimplemented: Convert from type %s to type %s",
                             PrimitiveType_Name(shape().element_type()).c_str(),
                             PrimitiveType_Name(primitive_dest_type).c_str());
  }
}

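A hypothetical call site for the new instance-method form introduced here; the CreateR1, Convert, and GetFirstElement calls follow the API shown in this diff, but the snippet itself is illustrative, not taken from the tree:

// Sketch only: convert an S32 literal to F32 via the new instance API.
std::unique_ptr<xla::Literal> ints =
    xla::Literal::CreateR1<xla::int32>({1, 2, 3});
xla::StatusOr<std::unique_ptr<xla::Literal>> floats = ints->Convert(xla::F32);
if (floats.ok()) {
  // Each element was static_cast to the destination native type.
  float first = floats.ValueOrDie()->GetFirstElement<float>();  // 1.0f
}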
@ -251,7 +251,7 @@ class Literal {
    *other = temp;
  }

  // Creates new literal of a given rank. To minimize ambiguity (for users
  // Creates a new literal of a given rank. To minimize ambiguity (for users
  // and the compiler) these CreateR[0-2] methods should explicitly specify the
  // native type. For example:
  //
@ -362,10 +362,10 @@ class Literal {
  template <typename NativeT>
  std::unique_ptr<Literal> Replicate(int64 times) const;

  // Creates a literal by converting each element in this literal to a new
  // type.
  template <typename NativeSrcT, typename NativeDestT>
  std::unique_ptr<Literal> Convert() const;
  // Converts this literal to another primitive type. Returns an error if the
  // conversion is not possible.
  StatusOr<std::unique_ptr<Literal>> Convert(
      PrimitiveType primitive_dest_type) const;

  // Creates a literal value zero of the given primitive type.
  static Literal Zero(PrimitiveType primitive_type);
@ -444,10 +444,21 @@ class Literal {
  template <typename NativeT>
  void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);

  // Retrieves the mutable array slice interface which can be used to manipulate
  // pre-allocated literal values.
  // Returns a (Mutable)ArraySlice view of the array for this literal for the
  // given NativeT (e.g., float). These functions map native type to XLA
  // PrimitiveType via template specialization. The unspecialized forms below
  // abort to handle the error case where the given native type does not map to
  // an XLA primitive type.
  template <typename NativeT>
  tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice();
  tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
    static_assert(!std::is_same<NativeT, NativeT>::value,
                  "Cannot map native type to primitive type.");
  }
  template <typename NativeT>
  tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice() {
    static_assert(!std::is_same<NativeT, NativeT>::value,
                  "Cannot map native type to primitive type.");
  }

  // Returns the element value at index (0, ..., 0), however many zeroes are
  // required for that index.
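The always-false dependent static_assert above is the standard trick for making the unspecialized template fail at instantiation time, while explicit specializations supply the supported mappings. A stripped-down illustration with hypothetical names:

// Primary template: instantiating it for an unmapped type fails to compile.
#include <type_traits>

template <typename T>
const char* PrimitiveTypeName() {
  static_assert(!std::is_same<T, T>::value,
                "no XLA primitive type maps to this native type");
  return nullptr;
}

// Explicit specializations provide the supported native-type mappings.
template <>
inline const char* PrimitiveTypeName<float>() { return "F32"; }
template <>
inline const char* PrimitiveTypeName<int>() { return "S32"; }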
@ -588,17 +599,6 @@ class Literal {
  bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;

 private:
  // Returns an ArraySlice view of the array for this literal for the given
  // NativeT (e.g., float). These functions map native type to XLA PrimitiveType
  // via template specialization. The unspecialized forms below abort to handle
  // the error case where the given native type does not map to an XLA primitive
  // type.
  template <typename NativeT>
  tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
    static_assert(!std::is_same<NativeT, NativeT>::value,
                  "Cannot map native type to primitive type.");
  }

  // Copy from a LiteralProto instance.
  void CopyFromProto(const LiteralProto& literal_proto);

@ -646,544 +646,6 @@ class Literal {
  std::vector<Literal> tuple_literals_;
};

// Utility class for dealing with XLA literal values. Most methods are
// templated by native (host) type which corresponds to a unique XLA
// PrimitiveType. See ComputationBuilder for details. Not all primitive types
// defined in xla_data.proto have a corresponding native type or even have a
// storage location in the Literal proto yet (for example, primitive type F16).
//
// TODO(dnovillo) - All functions in this class simply redirect to the
// corresponding function in class Literal. Remove this class after converting
// all user code to use Literal directly.
class LiteralUtil {
 public:
  // Creates new literal of a given rank. To minimize ambiguity (for users and
  // the compiler) these CreateR[0-2] methods should explicitly specify the
  // native type. For example:
  //
  //   CreateR1<float>({1.0, 42.0});
  //   CreateR2<uint32>({{1, 2}, {3, 4}});
  //
  // The variants not ending with WithLayout use the default XLA layout for the
  // literal's linear representation in memory.
  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR0(NativeT value) {
    return Literal::CreateR0(value);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR1(
      tensorflow::gtl::ArraySlice<NativeT> values) {
    return Literal::CreateR1(values);
  }

  static std::unique_ptr<Literal> CreateR1(
      const tensorflow::core::Bitmap& values) {
    return Literal::CreateR1(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR2(
      std::initializer_list<std::initializer_list<NativeT>> values) {
    return Literal::CreateR2(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR2WithLayout(
      std::initializer_list<std::initializer_list<NativeT>> values,
      const Layout& layout) {
    return Literal::CreateR2WithLayout(values, layout);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR3(
      std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>
          values) {
    return Literal::CreateR3(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR3WithLayout(
      std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>
          values,
      const Layout& layout) {
    return Literal::CreateR3WithLayout(values, layout);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR4(
      std::initializer_list<std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>>
          values) {
    return Literal::CreateR4(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR4WithLayout(
      std::initializer_list<std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>>
          values,
      const Layout& layout) {
    return Literal::CreateR4WithLayout(values, layout);
  }

  // Creates a new Literal object with the shape specified as parameter.
  // The content of the literal values is the default value of the primitive
  // type of literal itself (0 for numeric types, and false for predicates).
  static std::unique_ptr<Literal> CreateFromShape(const Shape& shape) {
    return Literal::CreateFromShape(shape);
  }

  // Creates a new Literal object with its values having the primitive_type
  // type, and with dimensions defined by the dimensions parameter.
  // The content of the literal values is the default value of the primitive
  // type of literal itself (0 for numeric types, and false for predicates).
  static std::unique_ptr<Literal> CreateFromDimensions(
      PrimitiveType primitive_type,
      tensorflow::gtl::ArraySlice<int64> dimensions) {
    return Literal::CreateFromDimensions(primitive_type, dimensions);
  }

  // Copies the values from src_literal, starting at src_base shape indexes,
  // to dest_literal, starting at dest_base, where the copy size in each
  // dimension is specified by copy_size.
  //
  // The src_literal and dest_literal must have the same primitive type,
  // src_base+copy_size must fit the source literal dimensions, as well as
  // dest_base+copy_size must fit the destination literal dimensions.
  static Status Copy(const Literal& src_literal,
                     tensorflow::gtl::ArraySlice<int64> src_base,
                     Literal* dest_literal,
                     tensorflow::gtl::ArraySlice<int64> dest_base,
                     tensorflow::gtl::ArraySlice<int64> copy_size) {
    return dest_literal->Copy(src_literal, src_base, dest_base, copy_size);
  }

  // Creates a new value that has the equivalent value as literal, but conforms
  // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major
  // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension
  // layout and the value in the cell at any given logical index (i0, i1) will
  // be the same.
  //
  // Note: this is useful when the client wants to ensure that a value placed in
  // the XLA allocation tracker has a particular layout; for efficiency
  // purposes or avoiding unimplemented operation/layout combinations.
  static std::unique_ptr<Literal> Relayout(const Literal& literal,
                                           const Layout& new_layout) {
    return literal.Relayout(new_layout);
  }

  // Reshapes literal 'input' to have 'shape'. Both the original shape and
  // 'shape' must contain the same number of elements. The implementation
  // currently only supports monotonic dim0-major layouts.
  static StatusOr<std::unique_ptr<Literal>> Reshape(
      const xla::Literal& input, tensorflow::gtl::ArraySlice<int64> shape) {
    return input.Reshape(shape);
  }

  // Creates a new literal by reordering the dimensions of the original literal.
  // The given `permutation` must be a permutation of the dimension numbers
  // in the original literal, and it specifies the order of the new dimensions
  // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
  // For example, a transpose call on a literal of shape [3 x 8 x 4] and
  // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
  static std::unique_ptr<Literal> Transpose(
      const Literal& literal, tensorflow::gtl::ArraySlice<int64> permutation) {
    return literal.Transpose(permutation);
  }

  // Creates a sub-array from the given literal by extracting the indices
  // [start_index, limit_index) of each dimension. The result literal has the
  // same rank and layout as for the given literal. The number of indices in
  // start_indices and limit_indices must be the rank of the literal, and the
  // indices follow the order of the dimensions.
  static std::unique_ptr<Literal> Slice(
      const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
      tensorflow::gtl::ArraySlice<int64> limit_indices) {
    return literal.Slice(start_indices, limit_indices);
  }

  // Creates a literal with a prepended dimension with bound "times"; e.g. a
  // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input
  // literal replicated four times.
  template <typename NativeT>
  static std::unique_ptr<Literal> Replicate(const Literal& input, int64 times) {
    return input.Replicate<NativeT>(times);
  }

  // Creates a literal by converting each element in an original literal to a
  // new type.
  template <typename NativeSrcT, typename NativeDestT>
  static std::unique_ptr<Literal> Convert(const Literal& literal) {
    return literal.Convert<NativeSrcT, NativeDestT>();
  }

  // Converts a literal to another primitive type, but only if the literal
  // type is convertible into the destination type.
  static StatusOr<std::unique_ptr<Literal>> ConvertIfSrcTypeMatches(
      const Literal& src_literal, PrimitiveType primitive_dest_type);

  // Creates a literal value zero of the given primitive type.
  static Literal Zero(PrimitiveType primitive_type) {
    return Literal::Zero(primitive_type);
  }

  // Creates a literal value one of the given primitive type.
  static Literal One(PrimitiveType primitive_type) {
    return Literal::One(primitive_type);
  }

  // Creates a literal value containing the minimum value of the given
  // primitive type. For floating-point types, returns -inf.
  static Literal MinValue(PrimitiveType primitive_type) {
    return Literal::MinValue(primitive_type);
  }

  // Creates a literal value containing the maximum value of the given
  // primitive type. For floating-point types, returns inf.
  static Literal MaxValue(PrimitiveType primitive_type) {
    return Literal::MaxValue(primitive_type);
  }

  // Creates a literal of the given shape where each element is `value`.
  template <typename NativeT>
  static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
      tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
    return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value);
  }

  // Creates a new literal from an array. The variants not ending with
  // WithLayout use the default XLA layout for the literal's linear
  // representation in memory.
  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR2FromArray2D(
      const Array2D<NativeT>& values) {
    return Literal::CreateR2FromArray2D(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
      const Array2D<NativeT>& values, const Layout& layout) {
    return Literal::CreateR2FromArray2DWithLayout(values, layout);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR3FromArray3D(
      const Array3D<NativeT>& values) {
    return Literal::CreateR3FromArray3D(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
      const Array3D<NativeT>& values, const Layout& layout) {
    return Literal::CreateR3FromArray3DWithLayout(values, layout);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR4FromArray4D(
      const Array4D<NativeT>& values) {
    return Literal::CreateR4FromArray4D(values);
  }

  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
      const Array4D<NativeT>& values, const Layout& layout) {
    return Literal::CreateR4FromArray4DWithLayout(values, layout);
  }

  // Creates a new vector of U8s literal value from a string.
  static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value) {
    return Literal::CreateR1U8(value);
  }

  // Creates a linspace-populated literal with the given number of rows and
  // columns.
  static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
                                                      int64 rows, int64 cols) {
    return Literal::CreateR2F32Linspace(from, to, rows, cols);
  }

  // Creates a literal that projects the (x, y) dimensions given in values into
  // the z dimension given by "projection".
  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR3Projected(
      std::initializer_list<std::initializer_list<NativeT>> values,
      int64 projection) {
    return Literal::CreateR3Projected(values, projection);
  }

  // Creates a literal that projects the (x, y) dimensions given in values into
  // the z and p dimensions given.
  template <typename NativeT>
  static std::unique_ptr<Literal> CreateR4Projected(
      std::initializer_list<std::initializer_list<NativeT>> values,
      int64 projection_p, int64 projection_z) {
    return Literal::CreateR4Projected(values, projection_p, projection_z);
  }

  // Clones literal into an owned unique_ptr version.
  static std::unique_ptr<Literal> CloneToUnique(const Literal& literal) {
    return literal.CloneToUnique();
  }

  // Returns the linear index of the given index within the literal's
  // element_type repeated field.
  static int64 LinearIndex(const Literal& literal,
                           tensorflow::gtl::ArraySlice<int64> multi_index) {
    return literal.LinearIndex(multi_index);
  }

  // Gets or sets an element in the literal at the given index. The index is
  // CHECKed against the dimension sizes.
  template <typename NativeT>
  static NativeT Get(const Literal& literal,
                     tensorflow::gtl::ArraySlice<int64> multi_index) {
    return literal.Get<NativeT>(multi_index);
  }

  template <typename NativeT>
  static void Set(Literal* literal,
                  tensorflow::gtl::ArraySlice<int64> multi_index,
                  NativeT value) {
    literal->Set(multi_index, value);
  }

  // Retrieves the mutable array slice interface which can be used to manipulate
  // pre-allocated literal values.
  template <typename NativeT>
  static tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice(
      Literal* literal) {
    return literal->GetMutableArraySlice<NativeT>();
  }

  // Returns the element value at index (0, ..., 0), however many zeroes are
  // required for that index.
  template <typename NativeT>
  static NativeT GetFirstElement(const Literal& literal) {
    return literal.GetFirstElement<NativeT>();
  }

  // As Get(), but determines the correct type and converts the value
  // into text.
  static string GetAsString(const Literal& literal,
                            tensorflow::gtl::ArraySlice<int64> multi_index) {
    return literal.GetAsString(multi_index);
  }

  // Returns an identity matrix (rank 2) with the given row and column count.
  template <typename NativeT>
  static std::unique_ptr<Literal> MakeIdentityR2(int64 size) {
    return Literal::MakeIdentityR2<NativeT>(size);
  }

  // Returns a tuple literal composed of given literals.
  static std::unique_ptr<Literal> MakeTuple(
      tensorflow::gtl::ArraySlice<const Literal*> elements) {
    return Literal::MakeTuple(elements);
  }

  // Validates that the data payload of the literal matches the literal shape;
  // if it does not, an appropriate status is returned.
  static tensorflow::Status ValidateLiteral(const Literal& literal) {
    return literal.ValidateLiteral();
  }

  // Returns a string representation of the literal value.
  static string ToString(const Literal& literal) { return literal.ToString(); }

  // Invokes the "per cell" callback for each element in the provided
  // literal with the element's indices and a string representation of
  // the element's value.
  //
  // This function is useful if you want a polymorphic representation
  // of the tensor's elements (turning it to a string for something
  // like representation in a protobuf).
  static void EachCellAsString(
      const Literal& literal,
      const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                               const string& value)>& per_cell) {
    literal.EachCellAsString(per_cell);
  }

  template <typename NativeT>
  static void EachCell(
      const Literal& literal,
      std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                         NativeT value)>
          per_cell) {
    literal.EachCell<NativeT>(per_cell);
  }

  // Templated methods which populate the given repeated field in the Literal
  // proto with the given value(s). The Shape field of the Literal proto is set
  // to match the array dimensions and type. Examples:
  //
  //   // Populate with floats.
  //   Array2D<float> float_values = ...
  //   PopulateR2FromArray2D(values, literal);
  //
  //   // Populate with int32s.
  //   PopulateR2({{1, 2}, {3, 4}}, literal);
  //
  template <typename NativeT>
  static void PopulateR0(NativeT values, Literal* literal) {
    literal->PopulateR0(values);
  }

  template <typename NativeT>
  static void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values,
                         Literal* literal) {
    literal->PopulateR1(values);
  }

  static void PopulateR1(const tensorflow::core::Bitmap& values,
                         Literal* literal) {
    literal->PopulateR1(values);
  }

  template <typename NativeT>
  static void PopulateR2(
      std::initializer_list<std::initializer_list<NativeT>> values,
      Literal* literal) {
    literal->PopulateR2(values);
  }

  template <typename NativeT>
  static void PopulateR2WithLayout(
      std::initializer_list<std::initializer_list<NativeT>> values,
      const Layout& layout, Literal* literal) {
    literal->PopulateR2WithLayout(values, layout);
  }

  template <typename NativeT>
  static void PopulateR2FromArray2D(const Array2D<NativeT>& values,
                                    Literal* literal) {
    literal->PopulateR2FromArray2D(values);
  }

  template <typename NativeT>
  static void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
                                              const Layout& layout,
                                              Literal* literal) {
    literal->PopulateR2FromArray2DWithLayout(values, layout);
  }

  template <typename NativeT>
  static void PopulateR3FromArray3D(const Array3D<NativeT>& values,
                                    Literal* literal) {
    literal->PopulateR3FromArray3D(values);
  }

  template <typename NativeT>
  static void PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                              const Layout& layout,
                                              Literal* literal) {
    literal->PopulateR3FromArray3DWithLayout(values, layout);
  }

  template <typename NativeT>
  static void PopulateR4FromArray4D(const Array4D<NativeT>& values,
                                    Literal* literal) {
    literal->PopulateR4FromArray4D(values);
  }

  template <typename NativeT>
  static void PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                              const Layout& layout,
                                              Literal* literal) {
    literal->PopulateR4FromArray4DWithLayout(values, layout);
  }

  // Populates literal values by calling the generator function for every cell
  // in the literal object.
  template <typename NativeT>
  static Status Populate(
      Literal* literal,
      const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
          generator) {
    return literal->Populate(generator);
  }

  // Creates a Literal of the given dimensions with all elements set to the
  // given value.
  template <typename NativeT>
  static void PopulateWithValue(NativeT value,
                                tensorflow::gtl::ArraySlice<int64> dimensions,
                                Literal* literal) {
    return literal->PopulateWithValue(value, dimensions);
  }

  // Returns a pointer to the underlying vector containing the array data. Use
  // with care.
  static const void* InternalData(const Literal& literal) {
    return literal.InternalData();
  }

  static void* MutableInternalData(Literal* literal) {
    return literal->MutableInternalData();
  }

  // Allocates space in the underlying vector of the literal sufficient to hold
  // num_elements of the literal's primitive type. Values in the vector are set
  // to zero. num_elements must equal the number of elements in the literal's
  // shape.
  static void Reserve(int64 num_elements, Literal* literal) {
    literal->Reserve(num_elements);
  }

  // Allocates space in the underlying vector of the literal sufficient to hold
  // num_elements of the literal's primitive type and sets each element in the
  // literal to the given value. num_elements must equal the number of elements
  // in the literal's shape.
  template <typename NativeT>
  static void Resize(int64 num_elements, NativeT value, Literal* literal) {
    literal->Resize(num_elements, value);
  }

  // Returns true if the two given literals have the same shape and
  // values. Layout is not considered in the comparison.
  static bool Equal(const Literal& literal1, const Literal& literal2) {
    return literal1.Equal(literal2);
  }

  // Returns whether every element in the given literal is equal to value.
  //
  // value is an int8 because we expect this to be called with small
  // compile-time constants (0, -1, etc.) and so that whatever value you pass
  // can be represented exactly by floating-point types as small as 16 bits.
  //
  // If value doesn't fit in literal's type, returns false. Values of 1/0 are
  // considered equal to true/false; other values are not considered equal to
  // true.
  static bool IsAll(const Literal& literal, int8 value) {
    return literal.IsAll(value);
  }

  // Like IsAll(const Literal&, int8), except we check whether the literal is
  // equal to a particular floating-point number.
  //
  // If the literal is not a floating-point value, this always returns false.
  //
  // This casts value to the type of literal, then compares using ==. The usual
  // admonishments about floating-point equality checks apply. We expect you to
  // use this to check for values that can be expressed precisely as a float,
  // e.g. -0.5.
  static bool IsAllFloat(const Literal& literal, float value) {
    return literal.IsAllFloat(value);
  }

  // Returns whether the literal is zero at the specified index. The literal
  // must be an array.
  static bool IsZero(const Literal& literal,
                     tensorflow::gtl::ArraySlice<int64> indices) {
    return literal.IsZero(indices);
  }

  TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil);
};

// Declarations of template specializations for GetArraySlice and
// GetMutableArraySlice. The specializations map native type to XLA primitive
// type.
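Per the TODO above, call sites move from the deleted static wrappers to the Literal instance methods. An illustrative before/after; the call site itself is hypothetical, not taken from the tree:

// Before (static LiteralUtil wrapper):
//   std::unique_ptr<Literal> two = LiteralUtil::CreateR0<float>(2.0f);
//   float v = LiteralUtil::Get<float>(*two, {});
//
// After (Literal instance API):
//   std::unique_ptr<Literal> two = Literal::CreateR0<float>(2.0f);
//   float v = two->Get<float>({});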
@ -1759,27 +1221,6 @@ void Literal::PopulateWithValue(NativeT value,
  Resize<NativeT>(ShapeUtil::ElementsIn(shape()), value);
}

template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> Literal::Convert() const {
  const Shape& this_shape = shape();
  auto result_literal = MakeUnique<Literal>();
  Shape* result_shape = result_literal->mutable_shape();
  *result_shape = this_shape;
  result_shape->set_element_type(
      primitive_util::NativeToPrimitiveType<NativeDestT>());
  result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
  tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
      GetArraySlice<NativeSrcT>();
  tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
      result_literal->GetMutableArraySlice<NativeDestT>();
  int64 num_elements = ShapeUtil::ElementsIn(this_shape);

  for (int64 i = 0; i < num_elements; ++i) {
    dest_data[i] = static_cast<NativeDestT>(src_data[i]);
  }
  return result_literal;
}

template <typename NativeT>
/* static */ std::unique_ptr<Literal>
Literal::CreateFullWithMonotonicDim0MajorLayout(
File diff suppressed because it is too large
@ -58,8 +58,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
  }

  int64 elements = ShapeUtil::ElementsIn(shape);
  LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
                      result.get());
  result.get()->Resize(elements, std::numeric_limits<float>::quiet_NaN());
  std::vector<float>* field = result->mutable_f32s();
  char* data = tensorflow::bit_cast<char*>(field->data());
  uint64 bytes = elements * sizeof(float);
@ -52,7 +52,7 @@ class ReferenceUtilTest : public ::testing::Test {

TEST_F(ReferenceUtilTest, TransposeArray2D) {
  auto result = ReferenceUtil::TransposeArray2D(*matrix_);
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
  auto actual_literal = Literal::CreateR2FromArray2D(*result);
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
                                       *actual_literal, ErrorSpec(0.0001));
}
@ -62,7 +62,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
      {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f},
  });
  auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
  auto actual_literal = Literal::CreateR2FromArray2D(*result);
  LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
                                       *actual_literal, ErrorSpec(0.0001));
}
@ -70,7 +70,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
  auto add = [](float lhs, float rhs) { return lhs + rhs; };
  auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
  auto actual_literal = LiteralUtil::CreateR1<float>(*result);
  auto actual_literal = Literal::CreateR1<float>(*result);
  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
                                       ErrorSpec(0.0001));
}
@ -78,7 +78,7 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
  auto add = [](float lhs, float rhs) { return lhs + rhs; };
  auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
  auto actual_literal = LiteralUtil::CreateR1<float>(*result);
  auto actual_literal = Literal::CreateR1<float>(*result);
  LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
                                       ErrorSpec(0.0001));
}
@ -86,7 +86,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
TEST_F(ReferenceUtilTest, MapArray2D) {
  auto identity = [](float value) { return log(exp(value)); };
  auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
  auto actual_literal = Literal::CreateR2FromArray2D(*result);
  LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
                                       ErrorSpec(0.0001));
}
@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
    return value + row + col;
  };
  auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
  auto actual_literal = Literal::CreateR2FromArray2D(*result);
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
                                       *actual_literal, ErrorSpec(0.0001));
}
@ -107,7 +107,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
  input->FillWithMultiples(1.0f);
  auto multiply_by_two = [](float value) { return 2 * value; };
  auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
  auto actual_literal = Literal::CreateR4FromArray4D(*result);

  Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
  expected.FillWithMultiples(2.0f);
@ -124,7 +124,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
    return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
  };
  auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
  auto actual_literal = Literal::CreateR4FromArray4D(*result);

  Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
  expected.Fill(0.0f);
@ -161,7 +161,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
  }));
  // clang-format on

  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
  auto actual_literal = Literal::CreateR4FromArray4D(*actual);

  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
                                              ErrorSpec(0.0001));
@ -195,7 +195,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
  }));
  // clang-format on

  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
  auto actual_literal = Literal::CreateR4FromArray4D(*actual);

  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
                                              ErrorSpec(0.0001));
@ -247,7 +247,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
  }});
  // clang-format on

  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
  auto actual_literal = Literal::CreateR4FromArray4D(*actual);

  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
                                              ErrorSpec(0.0001));
@ -296,7 +296,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
  Array4D<float> expected({{{{2514, 2685}}}});
  // clang-format on

  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
  auto actual_literal = Literal::CreateR4FromArray4D(*actual);

  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
                                              ErrorSpec(0.0001));
@ -309,7 +309,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {

  auto actual = ReferenceUtil::ApplyElementwise2D(
      [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
  auto actual_literal = Literal::CreateR2FromArray2D(*actual);
  LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
                                *actual_literal, ErrorSpec(0.0001));
}
@ -112,6 +112,7 @@ cc_test(
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:test_main",
@ -330,6 +331,7 @@ cc_library(
    hdrs = ["backend.h"],
    deps = [
        ":compiler",
        ":computation_placer",
        ":device_memory_allocator",
        ":platform_util",
        ":pool",
@ -338,7 +340,6 @@ cc_library(
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/legacy_flags:backend_flags",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
@ -382,6 +383,7 @@ cc_library(
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/legacy_flags:backend_flags",
        "//tensorflow/compiler/xla/legacy_flags:service_flags",
        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
        "//tensorflow/core:lib",
@ -416,6 +418,7 @@ cc_library(
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/legacy_flags:service_flags",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
@ -948,6 +951,26 @@ cc_test(
    ],
)

cc_library(
    name = "computation_placer",
    srcs = ["computation_placer.cc"],
    hdrs = ["computation_placer.h"],
    deps = [
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
    alwayslink = True,  # Contains per-platform computation placer registration
)

cc_library(
    name = "generic_transfer_manager",
    srcs = ["generic_transfer_manager.cc"],
@ -1165,6 +1188,7 @@ cc_library(
    deps = [
        ":call_graph",
        ":hlo",
        ":hlo_ordering",
        ":liveness_util",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:shape_util",
@ -1398,7 +1422,6 @@ cc_library(
        ":call_graph",
        ":flatten_call_graph",
        ":hlo",
        ":hlo_cost_analysis",
        ":hlo_dce",
        ":hlo_ordering",
        ":liveness_util",
@ -1572,10 +1595,8 @@ cc_test(
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/core:lib",
    ],
)

@ -1777,7 +1798,6 @@ cc_library(
        ":hlo",
        ":hlo_proto",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/core:lib",
    ],
)
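The alwayslink = True on computation_placer above keeps the target's static registrars alive even though no symbol in it is referenced directly. A hedged sketch of the kind of load-time registration this protects; the names are hypothetical, not from the target:

// Hypothetical registrar: a namespace-scope initializer runs when the
// library is loaded. Without alwayslink, the linker could drop this whole
// object file, since nothing references it by name.
namespace {

bool RegisterHostComputationPlacer() {
  // In the real target this would call into a per-platform registry,
  // e.g. registering a factory keyed by platform id.
  return true;
}

const bool kPlacerRegistered = RegisterHostComputationPlacer();

}  // namespace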
@ -48,7 +48,7 @@ namespace {
// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
  return operand->opcode() == HloOpcode::kConstant &&
         LiteralUtil::IsAll(operand->literal(), value);
         operand->literal().IsAll(value);
}

bool IsAll(const HloInstruction* op, int8 value) {
@ -126,10 +126,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;

  Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert,
                       HloInstruction* operand) override;
  Status HandleConvert(HloInstruction* convert) override;

  Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
                           HloInstruction* rhs, const Window& window) override;
@ -179,11 +178,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
  Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs,
                        HloInstruction* rhs) override;

  Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
                       HloInstruction* rhs) override;

  Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
                       HloInstruction* rhs) override;
  Status HandleMaximum(HloInstruction* maximum) override;
  Status HandleMinimum(HloInstruction* minimum) override;

  // Returns whether algebraic simplification has occurred.
  const bool changed() const { return changed_; }
@ -334,16 +330,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
                                              HloInstruction* operand) {
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  // If a copy feeds a copy, make it a single copy.
  if (operand->opcode() == HloOpcode::kCopy) {
  if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy,
                                          operand->operands()[0]));
        copy, HloInstruction::CreateUnary(
                  copy->shape(), HloOpcode::kCopy,
                  copy->mutable_operand(0)->mutable_operand(0)));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  ReplaceInstructionIfSameShape(copy, operand);
  ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
  return Status::OK();
}

@ -469,7 +465,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
      ShapeUtil::HasZeroElements(lhs->shape()) ||
      ShapeUtil::HasZeroElements(rhs->shape())) {
    auto zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }
@ -507,7 +503,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
        computation_->parent(), F32, HloOpcode::kAdd);
    auto zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
        ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
        {0}, add_reduce_computation));
@ -531,7 +527,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
        computation_->parent(), F32, HloOpcode::kAdd);
    auto zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    HloInstruction* reduce;
    if (ShapeUtil::Rank(rhs->shape()) == 1) {
      auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
@ -571,7 +567,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
        computation_->parent(), F32, HloOpcode::kAdd);
    auto zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
        ShapeUtil::MakeShape(dot->shape().element_type(),
                             {lhs->shape().dimensions(0)}),
@ -792,12 +788,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// A conversion to the same element type as the operand is a nop and can be
// removed. A conversion of a constant can be simplified by making a new
// constant.
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
                                                 HloInstruction* operand) {
  PrimitiveType src_type = operand->shape().element_type();
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
  PrimitiveType src_type = convert->operand(0)->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  if (src_type == dest_type) {
    return ReplaceInstruction(convert, operand);
    return ReplaceInstruction(convert, convert->mutable_operand(0));
  }
  return Status::OK();
}
@ -897,8 +892,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
                                               HloInstruction* rhs) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  if (IsAll(rhs, 0)) {
    auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
        LiteralUtil::One(power->shape().element_type())));
    auto one = HloInstruction::CreateConstant(
        Literal::One(power->shape().element_type()).CloneToUnique());
    std::unique_ptr<HloInstruction> ones;
    if (ShapeUtil::IsScalar(power->shape())) {
      ones = std::move(one);
@ -923,9 +918,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsAll(rhs, -1)) {
    auto* one = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
            LiteralUtil::One(rhs->shape().element_type()))));
    auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
        Literal::One(rhs->shape().element_type()).CloneToUnique()));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            one, lhs));
@ -1008,7 +1002,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  // dimension.
  if (ShapeUtil::HasZeroElements(reshape->shape())) {
    auto empty_constant = HloInstruction::CreateConstant(
        LiteralUtil::CreateFromShape(reshape->shape()));
        Literal::CreateFromShape(reshape->shape()));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }
@ -1208,8 +1202,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
  // try to get more fancy about proving equivalence in cases beyond that.
  if (pad_value->opcode() != HloOpcode::kConstant ||
      reduce_init_value->opcode() != HloOpcode::kConstant ||
      !LiteralUtil::Equal(pad_value->literal(),
                          reduce_init_value->literal())) {
      !pad_value->literal().Equal(reduce_init_value->literal())) {
    VLOG(10) << "Not folding pad into reduce-window due to different pad "
                "values.";
    return Status::OK();
@ -1396,9 +1389,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
  return true;
}

Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
  // Match the following tree:
  //          min_operand     operand
  //                     \   /
@ -1429,9 +1420,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
  // Match the following tree:
  //          max_operand     operand
  //                     \   /
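The visitor handlers above lose their pre-fetched operand parameters and now read operands off the instruction itself. A sketch of the new shape, mirroring the HandleCopy rewrite; the method body is illustrative and assumes the surrounding visitor class (Status, HloInstruction, HloOpcode, ReplaceWithNewInstruction):

// Sketch only, not from the tree.
Status HandleCopySketch(HloInstruction* copy) {
  // Operands are no longer passed in; fetch them from the instruction.
  HloInstruction* operand = copy->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kCopy) {
    // copy(copy(x)) => copy(x)
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy,
                                          operand->mutable_operand(0)));
  }
  return Status::OK();
}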
@ -55,7 +55,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));

@ -76,7 +76,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
  builder.AddInstruction(
@ -99,7 +99,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
      HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
  builder.AddInstruction(
@ -123,7 +123,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));

@ -145,7 +145,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
  HloInstruction* div = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));

@ -167,7 +167,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
      Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
  HloInstruction* div = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));

@ -300,7 +300,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));

@ -315,7 +315,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Constant());
  EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->literal()), 1);
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}

// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
@ -325,7 +325,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

@ -344,8 +344,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
      << ShapeUtil::HumanString(root->shape());
  EXPECT_EQ(root->dimensions().size(), 0);
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
            1);
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}

// Test that pow(A, 1) is simplified to A.
@ -355,7 +354,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));

@ -378,7 +377,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));

@ -401,7 +400,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* negative_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                      param0, negative_one));

@ -416,8 +415,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Divide(op::Constant(), param0));
  EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
            1);
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}

TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
@ -451,7 +449,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));

@ -519,7 +517,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "param1"));
  HloInstruction* empty_literal = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
      HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
  HloInstruction* empty_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
@ -550,7 +548,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* empty_literal = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
      HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
  HloInstruction* empty_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
@ -735,7 +733,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
                                   HloOpcode::kMaximum, movable_reshape, zero));
@ -1035,7 +1033,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  PaddingConfig no_padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = no_padding.add_dimensions();
@ -1066,7 +1064,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
PaddingConfig padding;
|
||||
int64 low_padding[2] = {-1, -2};
|
||||
int64 high_padding[2] = {2, -3};
|
||||
@ -1376,9 +1374,9 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* min_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction* max_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
r0f32, HloOpcode::kMinimum, param0, min_value));
|
||||
builder.AddInstruction(
|
||||
@ -1406,9 +1404,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* min_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction* max_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
r0f32, HloOpcode::kMaximum, param0, max_value));
|
||||
builder.AddInstruction(
|
||||
@ -1437,9 +1435,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r1f32, "param0"));
|
||||
HloInstruction* min_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction* max_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
r1f32, HloOpcode::kMaximum, param0, max_value));
|
||||
builder.AddInstruction(
|
||||
@ -1497,9 +1495,9 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* min_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
|
||||
HloInstruction* max_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
r0f32, HloOpcode::kMaximum, param0, max_value));
|
||||
HloInstruction* fmax = builder.AddInstruction(
|
||||
@ -1566,7 +1564,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
|
||||
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* forty_two = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
|
||||
|
||||
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
|
||||
HloInstruction* broadcast =
|
||||
@ -1614,7 +1612,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
padding.mutable_dimensions(3)->set_edge_padding_high(2);
|
||||
|
||||
HloInstruction* pad_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
|
||||
|
||||
@ -1645,7 +1643,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
const Shape reduce_window_shape =
|
||||
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
|
||||
HloInstruction* reduce_init_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
|
||||
HloInstruction* reduce_window =
|
||||
builder.AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
reduce_window_shape, pad, reduce_init_value, window,
|
||||
@ -1714,9 +1712,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
|
||||
|
||||
HloComputation::Builder call_builder(TestName() + ".Call");
|
||||
HloInstruction* zero = call_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
|
||||
HloInstruction* one = call_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
|
||||
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
  return platform_;
}

BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int BackendOptions::number_of_replicas() const { return number_of_replicas_; }

BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
@ -85,20 +77,17 @@ struct Backend::EigenThreadPoolWrapper {

/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
    const BackendOptions& options) {
  int64 replica_count = options.number_of_replicas();
  if (replica_count == -1) {
    legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
    replica_count = flags->xla_replicas;
  }
  perftools::gputools::Platform* platform = options.platform();
  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(auto stream_executors,
                      PlatformUtil::GetStreamExecutors(platform));
  TF_ASSIGN_OR_RETURN(auto transfer_manager,
                      TransferManager::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(auto computation_placer,
                      ComputationPlacer::GetForPlatform(platform));
  std::unique_ptr<Backend> backend(
      new Backend(replica_count, platform, compiler, stream_executors,
                  transfer_manager, options.intra_op_parallelism_threads()));
      new Backend(platform, compiler, stream_executors, transfer_manager,
                  computation_placer, options.intra_op_parallelism_threads()));
  return std::move(backend);
}

@ -132,34 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
}

Backend::Backend(
    int64 replica_count, perftools::gputools::Platform* platform,
    Compiler* compiler,
    perftools::gputools::Platform* platform, Compiler* compiler,
    tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
    TransferManager* transfer_manager, int intra_op_parallelism_threads)
    TransferManager* transfer_manager, ComputationPlacer* computation_placer,
    int intra_op_parallelism_threads)
    : platform_(platform),
      compiler_(compiler),
      transfer_manager_(transfer_manager),
      replica_count_(replica_count) {
      computation_placer_(computation_placer) {
  // The given set of stream executors may include invalid executors.
  for (se::StreamExecutor* exec : stream_executors) {
    if (exec != nullptr) {
      stream_executors_.push_back(exec);
    }
  }
  CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";

  // Create a memory allocator for the valid stream executors.
  memory_allocator_ =
      MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);

  // First check that there are some non-null stream executors to avoid issuing
  // an error mentioning replicas in the common case of requesting just 1
  // replica, which means no replication.
  CHECK(!stream_executors_.empty())
      << "Service found no devices for backend " << platform_->Name() << '.';
  CHECK_GE(stream_executors_.size(), replica_count)
      << "Requested more replicas than there are devices for backend "
      << platform_->Name() << '.';

  if (platform->id() == se::host::kHostPlatformId) {
    inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
@ -179,36 +159,6 @@ int Backend::default_device_ordinal() const {
  return default_stream_executor()->device_ordinal();
}

StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
    int device_ordinal) const {
  if (stream_executors_[device_ordinal] == nullptr) {
    return InvalidArgument("device %s not supported by XLA service",
                           device_name(device_ordinal).c_str());
  }

  // Find replica_count_ stream executors starting from the given device
  // ordinal.
  std::vector<perftools::gputools::StreamExecutor*> replicas;
  for (se::StreamExecutor* exec : stream_executors_) {
    CHECK(exec != nullptr);
    if (exec->device_ordinal() >= device_ordinal) {
      replicas.push_back(exec);
      if (replicas.size() >= replica_count_) {
        return replicas;
      }
    }
  }

  return InvalidArgument(
      "Not enough devices for replicas for the device ordinal %d",
      device_ordinal);
}

std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
  CHECK_GE(stream_executors_.size(), replica_count_);
  return Replicas(default_device_ordinal()).ValueOrDie();
}

tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
  return inter_op_thread_pool_.get();
}

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -46,12 +47,6 @@ class BackendOptions {
  BackendOptions& set_platform(perftools::gputools::Platform* platform);
  perftools::gputools::Platform* platform() const;

  // Set the number of replicas to use when compiling replicated
  // programs. The default is -1 meaning that the value is read from
  // the xla_replicas flag.
  BackendOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual operator.
  // The default value of -1 will result in initializing the thread pool with
  // the number of threads equal to the number of cores in the system.
@ -60,7 +55,6 @@ class BackendOptions {

 private:
  perftools::gputools::Platform* platform_ = nullptr;
  int number_of_replicas_ = -1;
  int intra_op_parallelism_threads_ = -1;
};

@ -74,8 +68,7 @@ class Backend {
 public:
  using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;

  // Creates a new backend for the given platform with the given number of
  // replicas.
  // Creates a new backend.
  static StatusOr<std::unique_ptr<Backend>> CreateBackend(
      const BackendOptions& options);

@ -92,6 +85,7 @@ class Backend {
    return memory_allocator_.get();
  }
  TransferManager* transfer_manager() const { return transfer_manager_; }
  ComputationPlacer* computation_placer() const { return computation_placer_; }

  // Returns the number of devices of the platform type which are visible. Not
  // all of these devices may be usable by XLA.
@ -107,24 +101,13 @@ class Backend {
    return stream_executors_;
  }

  // Returns the replicas for the default stream executor.
  //
  // When the number of replicas is R, the first R stream executors are assigned
  // to the replicas of the default stream executor.
  std::vector<perftools::gputools::StreamExecutor*> Replicas() const;

  // Returns the replicas for the given device_ordinal. The given device ordinal
  // is considered to be the first device ordinal among the replicas. Returns an
  // error status if the stream executor for the given device ordinal does
  // not exist or if there are not enough stream executors for the replicas.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      int device_ordinal) const;

  // Return the stream executor for the given device ordinal.
  // Returns the stream executor for the given device ordinal.
  StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
      int device_ordinal) const;

  // Return the stream executor for the default device ordinal.
  // Returns the stream executor for the default device ordinal. This stream
  // executor can only be used when the number of computations is 1 (replication
  // can be > 1).
  perftools::gputools::StreamExecutor* default_stream_executor() const {
    CHECK(!stream_executors_.empty());
    return stream_executors_[0];
@ -174,18 +157,19 @@ class Backend {

 private:
  struct EigenThreadPoolWrapper;
  Backend(int64 replica_count, perftools::gputools::Platform* platform,
          Compiler* compiler,
  Backend(perftools::gputools::Platform* platform, Compiler* compiler,
          tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
              stream_executors,
          TransferManager* transfer_manager, int intra_op_parallelism_threads);
          TransferManager* transfer_manager,
          ComputationPlacer* computation_placer,
          int intra_op_parallelism_threads);
  Backend(const Backend&) = delete;
  Backend& operator=(const Backend&) = delete;

  perftools::gputools::Platform* platform_;
  Compiler* compiler_;
  TransferManager* transfer_manager_;
  int64 replica_count_ = -1;
  ComputationPlacer* computation_placer_;

  // Vector of stream executors. stream_executors_[0] is the default executor.
  std::vector<perftools::gputools::StreamExecutor*> stream_executors_;

@ -1074,7 +1074,8 @@ void BufferAssigner::AddSetToColocatedBufferSets(
// different while instructions.
void BufferAssigner::AddWhileSetToColocatedBufferSets(
    const std::vector<const LogicalBuffer*>& colocated_set,
    const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
    const LogicalBuffer* while_init_buffer,
    const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
    const HloComputation& computation, const BufferLiveness& buffer_liveness,
    const LogicalBuffer::SizeFunction& buffer_size,
    std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
@ -1137,16 +1138,30 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets(
      continue;
    }

    // Skip predecessor set if the live range of any predecessor buffers
    // overlaps with 'while_init_buffer'. Note that tuple element buffer
    // forwarding can cause the same buffer to appear on both sides of the
    // interference comparison below.
    if (std::any_of(
            predecessor_while_buffers.begin(), predecessor_while_buffers.end(),
            [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) {
              return while_init_buffer->id() != buffer->id() &&
                     buffer_liveness.MayInterfere(*while_init_buffer, *buffer);
            })) {
    // Skip predecessor set if the live range of any predecessor
    // buffers overlaps with 'while_init_buffer' or
    // 'while_result_buffer' (we need to check both since they're
    // aliased together, but the points-to analysis is unaware of this
    // aliasing). Note that tuple element buffer forwarding can cause
    // the same buffer to appear on both sides of the interference
    // comparison below.
    auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) {
      if (while_init_buffer->id() != buffer->id() &&
          buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) {
        return true;
      }

      if (while_result_buffer->id() != buffer->id() &&
          buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) {
        return true;
      }

      return false;
    };

    if (std::any_of(predecessor_while_buffers.begin(),
                    predecessor_while_buffers.end(),
                    may_interfere_with_init_or_result)) {
      continue;
    }

@ -1209,8 +1224,8 @@ void BufferAssigner::BuildColocatedBufferSets(
          AddBufferToColocatedSet(while_hlo->operand(0), index,
                                  points_to_analysis, &colocated_set);
          // Add while.result.
          AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
                                  &colocated_set);
          auto* result_buffer = AddBufferToColocatedSet(
              while_hlo, index, points_to_analysis, &colocated_set);
          // Add while.cond.parameter.
          AddBufferToColocatedSet(
              while_hlo->while_condition()->parameter_instruction(0), index,
@ -1224,8 +1239,9 @@ void BufferAssigner::BuildColocatedBufferSets(
              while_hlo->while_body()->root_instruction(), index,
              points_to_analysis, &colocated_set);
          AddWhileSetToColocatedBufferSets(
              colocated_set, init_buffer, while_hlo, *computation,
              buffer_liveness, buffer_size, colocated_buffer_sets);
              colocated_set, init_buffer, result_buffer, while_hlo,
              *computation, buffer_liveness, buffer_size,
              colocated_buffer_sets);
        });
  } else if (opcode == HloOpcode::kCall) {
    const HloInstruction* call_hlo = instruction;

@ -511,7 +511,8 @@ class BufferAssigner {
  // colocated buffers for while instructions.
  void AddWhileSetToColocatedBufferSets(
      const std::vector<const LogicalBuffer*>& colocated_set,
      const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
      const LogicalBuffer* while_init_buffer,
      const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
      const HloComputation& computation, const BufferLiveness& buffer_liveness,
      const LogicalBuffer::SizeFunction& buffer_size,
      std::vector<ColocatedBufferSet>* colocated_buffer_sets);

@ -105,7 +105,7 @@ class BufferAssignmentTest : public HloTestBase {
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    auto value = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
        HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
    return builder.Build();
@ -122,7 +122,7 @@ class BufferAssignmentTest : public HloTestBase {
      const string& name) {
    auto builder = HloComputation::Builder(name);
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
        HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto index = builder.AddInstruction(
@ -147,9 +147,9 @@ class BufferAssignmentTest : public HloTestBase {
      const string& name) {
    auto builder = HloComputation::Builder(name);
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
        HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
        Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto indexc = builder.AddInstruction(
@ -264,7 +264,7 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
TEST_F(BufferAssignmentTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

@ -278,9 +278,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
  // no buffers assigned, and their consumer has a buffer.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
      Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
  auto module = CreateNewModule();
@ -298,7 +298,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // This computation copies a constant to output.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
  auto module = CreateNewModule();
@ -586,7 +586,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/exp2,
@ -634,9 +634,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
      HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
  auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
@ -1075,9 +1075,8 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
  // Test that a tuple constant which is forwarded to the computation output is
  // properly handled.
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
                              LiteralUtil::CreateR0<int64>(1).get()})));
  builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
      {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
@ -1369,9 +1368,9 @@ class WhileBufferAssignmentTest : public HloTestBase {
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
    auto zero = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
        HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
    auto ten = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
        HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
    builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
    return builder.Build();
@ -1429,7 +1428,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
      HloInstruction::CreateParameter(2, data_shape_, "weights1"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
  auto output1 = builder.AddInstruction(
@ -1484,7 +1483,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
  auto output1 = builder.AddInstruction(
@ -1532,16 +1531,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32, "param"));
    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
        HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
  }
  auto builder = HloComputation::Builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
@ -1565,6 +1564,104 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}

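// Checks that 'sequence' visits every instruction after all of its operands
// and control predecessors, with no repeats. For example, for c = add(a, b),
// {a, b, c} is a valid post-order traversal, while {c, a, b} (an operand
// appears after its user) and {a, a, b, c} (a repeated instruction) are not.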
static bool IsPostOrderTraversal(
    const std::vector<const HloInstruction*>& sequence) {
  tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
  auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
    return seen_so_far.count(instruction) == 0;
  };

  for (auto instruction : sequence) {
    if (std::any_of(instruction->operands().begin(),
                    instruction->operands().end(), has_not_been_seen_yet) ||
        std::any_of(instruction->control_predecessors().begin(),
                    instruction->control_predecessors().end(),
                    has_not_been_seen_yet)) {
      return false;  // Not a post order.
    }
    if (!seen_so_far.insert(instruction).second) {
      return false;  // Not a "traversal".
    }
  }

  return true;
}

TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto module = MakeUnique<HloModule>(TestName());
  auto builder = HloComputation::Builder(TestName());

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));

  auto input1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "input1"));
  auto weights1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, data_shape_, "weights1"));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, one, {1}));

  auto cond =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));

  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));

  auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
      while0->shape(), HloOpcode::kAdd, while0, while1));
  module->AddEntryComputation(builder.Build());

  RunCopyInsertion(module.get());

  {
    FlattenCallGraph flatten;
    TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  auto sequence =
      CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731, we want a specific Hlo sequence for the
  // root computation, so we overwrite that entry with a manually
  // crafted sequence.
  std::vector<const HloInstruction*> sequence_for_buffer_assigment = {
      input1, weights1, one, output1, tuple1, while1, input0,
      weights0, zero, output0, tuple0, while0, root_add};

  // If this ASSERT_TRUE fails, we constructed a bogus sequence above
  // and this test itself is buggy.
  ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment));

  sequence[module->entry_computation()] =
      std::move(sequence_for_buffer_assigment);

  auto assignment = BufferAssigner::Run(module.get(),
                                        MakeUnique<SequentialHloOrdering>(
                                            module.get(), sequence),
                                        ByteSizeOf, 1)
                        .ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}

// Test buffer assignment for while nodes with multiple uses.
// TODO(b/37245345): Fix buffer assignment for this case.
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
@ -1577,7 +1674,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));

@ -122,7 +122,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
    if (b.instruction()->IsUserOf(alias.instruction()) &&
        !CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
                                       b.instruction(), b.index(),
                                       points_to_analysis())) {
                                       &points_to_analysis())) {
      return false;
    }
  }

@ -397,13 +397,11 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
  // computation. The buffer containing {0, 1} is copied by GetTupleElement, and
  // the buffers containing {3} and 3 are dead.
  auto builder = HloComputation::Builder(TestName());
  auto inner_tuple0 =
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
                              LiteralUtil::CreateR0<int64>(1).get()});
  auto inner_tuple1 =
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
  auto inner_tuple0 = Literal::MakeTuple(
      {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
  auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
      Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      inner_tuple0->shape(), tuple_constant, 0));

@ -450,7 +448,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      tuple_element0_shape, tuple_param0, 0));
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
      Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));

@ -462,7 +460,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      tuple_element1_shape, tuple_param0, 1));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
      Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));

@ -513,7 +511,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      tuple_element0_shape, tuple_param0, 0));
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
      Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));

@ -585,7 +583,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));

    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
        Literal::CreateR1<float>({2.f, 2.f, 2.f})));
    HloInstruction* slice = nullptr;
    if (update_uses_tuple_element1) {
      // Create a slice instruction as an additional user of 'gte1'.
@ -596,7 +594,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
    }
    // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
    auto starts = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
        HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
    auto dynamic_update_slice =
        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            data_shape, gte1, update, starts));
@ -715,7 +713,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));

    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
        Literal::CreateR1<float>({2.f, 2.f, 2.f})));

    if (tuple_element1_has_two_uses) {
      // Add 'gte0' and 'gte1' to create another user of 'gte1'.
@ -724,7 +722,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
    }
    // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
    auto starts = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
        HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
    auto dynamic_update_slice =
        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            data_shape, gte1, update, starts));

@ -133,6 +133,37 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  return nodes_[it->second];
}

bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited node
    // we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}
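// Implementation note: because the call graph is acyclic and 'visited' is
// checked before a node is expanded, each computation's callers are walked at
// most once, so the whole check runs in time linear in the number of call
// edges.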

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  tensorflow::gtl::FlatSet<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'

@ -189,6 +189,20 @@ class CallGraph {
  Status VisitNodes(const VisitorFunction& visitor_func,
                    bool visit_unreachable_nodes = true) const;

  // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
  // dominates computation 'b' iff all callgraph paths in the caller-to-callee
  // direction from a root computation to 'b' pass through computation
  // 'a'. Trivially, a computation dominates itself.
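  // For example, if the entry computation calls 'A' and 'B', and both 'A' and
  // 'B' call 'C', then the entry computation dominates 'C', but neither 'A'
  // nor 'B' does: each can be bypassed via a path through the other caller.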
  bool Dominates(const HloComputation* a, const HloComputation* b) const;

  // Returns whether 'instruction' is contained in 'computation' either directly
  // ('instruction->parent' is 'computation') or indirectly ('computation'
  // dominates 'instruction->parent' in the call graph).
  bool InstructionIsNestedIn(const HloInstruction* instruction,
                             const HloComputation* computation) const {
    return Dominates(computation, instruction->parent());
  }

  string ToString() const;

 private:
@ -205,6 +219,13 @@ class CallGraph {
      const VisitorFunction& visitor_func, const CallGraphNode& node,
      tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;

  // Recursive helper for computing whether 'a' dominates 'b' in the call
  // graph. 'b' is the currently visited node (the walk starts at 'b' and
  // proceeds through its callers), and 'visited' is the set of computations
  // which have been visited.
  bool DominatesHelper(
      const HloComputation* a, const HloComputation* b,
      tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;

  // The HLO module represented by this call graph.
  const HloModule* module_ = nullptr;
Some files were not shown because too many files have changed in this diff.