Merge commit for internal changes

This commit is contained in:
Yifei Feng 2017-06-21 10:25:00 -07:00
commit f69c7569cd
1131 changed files with 36858 additions and 89628 deletions

View File

@ -39,7 +39,7 @@ config_setting(
config_setting(
name = "android_armeabi",
values = {
"cc_target_os": "android",
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
@ -218,7 +218,9 @@ filegroup(
"//tensorflow/compiler/jit/ops:all_files",
"//tensorflow/compiler/tests:all_files",
"//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/cc:all_files",
"//tensorflow/compiler/tf2xla/kernels:all_files",
"//tensorflow/compiler/tf2xla/ops:all_files",
"//tensorflow/compiler/xla:all_files",
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
@ -253,7 +255,7 @@ filegroup(
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/data/python/util:all_files",
"//tensorflow/contrib/decision_trees:all_files",
"//tensorflow/contrib/decision_trees/proto:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",
@ -284,6 +286,8 @@ filegroup(
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/rnn:all_files",
"//tensorflow/contrib/saved_model:all_files",
"//tensorflow/contrib/saved_model/cc/saved_model:all_files",
@ -302,10 +306,13 @@ filegroup(
"//tensorflow/contrib/stateless:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensor_forest/kernels/v4:all_files",
"//tensorflow/contrib/tensor_forest/proto:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/tpu:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",
@ -353,70 +360,6 @@ filegroup(
"//tensorflow/python/ops/distributions:all_files",
"//tensorflow/python/saved_model:all_files",
"//tensorflow/python/tools:all_files",
"//tensorflow/tensorboard:all_files",
"//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files",
"//tensorflow/tensorboard/components/tf_backend:all_files",
"//tensorflow/tensorboard/components/tf_backend/test:all_files",
"//tensorflow/tensorboard/components/tf_color_scale:all_files",
"//tensorflow/tensorboard/components/tf_color_scale/test:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files",
"//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_globals:all_files",
"//tensorflow/tensorboard/components/tf_graph:all_files",
"//tensorflow/tensorboard/components/tf_graph/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_app:all_files",
"//tensorflow/tensorboard/components/tf_graph_app/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_board:all_files",
"//tensorflow/tensorboard/components/tf_graph_board/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_common:all_files",
"//tensorflow/tensorboard/components/tf_graph_controls:all_files",
"//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files",
"//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_info:all_files",
"//tensorflow/tensorboard/components/tf_graph_info/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_loader:all_files",
"//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files",
"//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_imports:all_files",
"//tensorflow/tensorboard/components/tf_option_selector:all_files",
"//tensorflow/tensorboard/components/tf_profile_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_runs_selector:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_storage:all_files",
"//tensorflow/tensorboard/components/tf_storage/test:all_files",
"//tensorflow/tensorboard/components/tf_tensorboard:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_trace_viewer:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_projector:all_files",
"//tensorflow/tensorboard/components/vz_projector/test:all_files",
"//tensorflow/tensorboard/components/vz_sorting:all_files",
"//tensorflow/tensorboard/components/vz_sorting/test:all_files",
"//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/audio:all_files",
"//tensorflow/tensorboard/plugins/distributions:all_files",
"//tensorflow/tensorboard/plugins/graphs:all_files",
"//tensorflow/tensorboard/plugins/histograms:all_files",
"//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",
"//tensorflow/tensorboard/plugins/scalars:all_files",
"//tensorflow/tensorboard/plugins/text:all_files",
"//tensorflow/tensorboard/scripts:all_files",
"//tensorflow/tools/api/golden:all_files",
"//tensorflow/tools/api/lib:all_files",
"//tensorflow/tools/api/tests:all_files",

View File

@ -628,7 +628,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
// Target nodes
const char** c_target_oper_names, int ntargets,
const char** handle, TF_Status* status) {
status->status = Status::OK();
*handle = nullptr;
std::vector<tensorflow::string> input_names(ninputs);
std::vector<tensorflow::string> output_names(noutputs);
@ -643,16 +643,12 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
target_oper_names[i] = c_target_oper_names[i];
}
tensorflow::string new_handle;
Status result;
result = s->session->PRunSetup(input_names, output_names, target_oper_names,
&new_handle);
if (result.ok()) {
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (status->status.ok()) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
} else {
*handle = nullptr;
status->status = result;
}
}
@ -2326,6 +2322,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
int ninputs, const TF_Output* outputs, int noutputs,
const TF_Operation* const* target_opers, int ntargets,
const char** handle, TF_Status* status) {
*handle = nullptr;
if (!ExtendSessionGraphHelper(session, status)) {
return;
}

View File

@ -1101,8 +1101,7 @@ TF_CAPI_EXPORT extern void TF_SessionRun(
// needed.
//
// On failure, out_status contains a tensorflow::Status with an error
// message.
// NOTE: This is EXPERIMENTAL and subject to change.
// message. *handle is set to nullptr.
TF_CAPI_EXPORT extern void TF_SessionPRunSetup(
TF_Session*,
// Input names
@ -1118,7 +1117,6 @@ TF_CAPI_EXPORT extern void TF_SessionPRunSetup(
// Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle.
// NOTE: This is EXPERIMENTAL and subject to change.
TF_CAPI_EXPORT extern void TF_SessionPRun(
TF_Session*, const char* handle,
// Input tensors

View File

@ -61,7 +61,6 @@ cc_library(
":gradients",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -274,10 +273,6 @@ cc_library(
deps = [
":cc_ops",
":grad_op_registry",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
],
)
@ -305,10 +300,6 @@ cc_library(
":cc_ops",
":cc_ops_internal",
":grad_op_registry",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
],
)
@ -527,7 +518,6 @@ cc_library(
deps = [
":coordinator",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@ -560,8 +550,6 @@ cc_library(
srcs = ["training/coordinator.cc"],
hdrs = ["training/coordinator.h"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",

View File

@ -21,6 +21,9 @@ namespace tensorflow {
/// SavedModel assets directory.
constexpr char kSavedModelAssetsDirectory[] = "assets";
/// SavedModel assets.extra directory.
constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
/// SavedModel assets key for graph collection-def.
constexpr char kSavedModelAssetsKey[] = "saved_model_assets";

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
@ -76,8 +77,16 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
return Status::OK();
}
}
string tags_as_string = "{ ";
for (const string& tag : tags) {
tags_as_string = strings::StrCat(tags_as_string, tag, " ");
}
tags_as_string = strings::StrCat(tags_as_string, "}");
return Status(error::Code::NOT_FOUND,
"Could not find meta graph def matching supplied tags.");
"Could not find meta graph def matching supplied tags: " +
tags_as_string +
". To inspect available tag-sets in the SavedModel, please "
"use the SavedModel CLI: `saved_model_cli`");
}
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,

View File

@ -133,9 +133,9 @@ TEST_F(LoaderTest, NoTagMatch) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
.contains("Could not find meta graph def matching supplied tags."))
EXPECT_TRUE(StringPiece(st.error_message())
.contains("Could not find meta graph def matching supplied "
"tags: { missing-tag }"))
<< st.error_message();
}
@ -151,7 +151,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
.contains("Could not find meta graph def matching supplied tags."))
.contains("Could not find meta graph def matching supplied tags: "))
<< st.error_message();
}

View File

@ -18,10 +18,13 @@ limitations under the License.
namespace tensorflow {
/// Tag for the `gpu` graph.
constexpr char kSavedModelTagGpu[] = "gpu";
/// Tag for the `serving` graph.
constexpr char kSavedModelTagServe[] = "serve";
/// Tag for the `training` graph.`
/// Tag for the `training` graph.
constexpr char kSavedModelTagTrain[] = "train";
} // namespace tensorflow

View File

@ -126,14 +126,11 @@ cc_library(
deps = [
":tfcompile_lib",
":tfcompile_proto",
"//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags",
"//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
"//tensorflow/compiler/xla/legacy_flags:llvm_util_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/compiler/xla/legacy_flags:util_flags",
"//tensorflow/compiler/xla/service:compiler",

View File

@ -23,14 +23,11 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -136,14 +133,11 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list);
xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list);
xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::legacy_flags::AppendLlvmUtilFlags(&flag_list);
xla::legacy_flags::AppendServiceFlags(&flag_list);
xla::legacy_flags::AppendUtilFlags(&flag_list);

View File

@ -22,20 +22,6 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# This target can be used by XLA device plugins to prevent circular
# dependencies, and provides access to all of the required headers
# for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
name = "jit",
@ -283,3 +269,15 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)

View File

@ -38,6 +38,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
@ -149,6 +150,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
xla::ExecutionOptions execution_options;
*execution_options.mutable_shape_with_output_layout() =
kernel->xla_output_shape;
*execution_options.mutable_debug_options() =
xla::legacy_flags::GetDebugOptionsFromFlags();
Env* env = Env::Default();
auto start_time = env->NowMicros();
VLOG(1) << "Executing XLA Computation...";
@ -202,8 +205,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
for (int i = 0; i < kernel->variable_updates.size(); ++i) {
const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i];
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));

View File

@ -1,32 +1,20 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
cc_library(
name = "xla_ops",
srcs = [
"xla_ops.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
srcs = ["xla_ops.cc"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
cc_library(
name = "parallel_check_op",
srcs = ["parallel_check_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args,
XlaCompiler::Argument& arg = (*args)[input_num];
arg.name = variable_args[variable_id].name;
arg.kind = XlaCompiler::Argument::kVariable;
if (variable_args[variable_id].present) {
const Tensor& value = variable_args[variable_id].value;
arg.kind = XlaCompiler::Argument::kVariable;
arg.type = value.dtype();
arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
// they are meaningless. However, it is legal to assign to a resource
// variable for the first time inside the XLA computation, so we do permit
// uninitialized variables.
arg.kind = XlaCompiler::Argument::kUninitializedVariable;
arg.initialized = false;
arg.type = DT_INVALID;
arg.shape = TensorShape();
}

View File

@ -137,7 +137,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(result.status());
return;
}
const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie());
const void* src_ptr = result.ValueOrDie()->InternalData();
void* dst_ptr = DMAHelper::base(cpu_tensor);
size_t total_bytes = cpu_tensor->TotalBytes();
memcpy(dst_ptr, src_ptr, total_bytes);

View File

@ -40,6 +40,7 @@ py_library(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
],
)
@ -323,7 +324,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "reverse_ops_test",
size = "small",
size = "medium",
srcs = ["reverse_ops_test.py"],
deps = [
":xla_test",

View File

@ -228,34 +228,40 @@ class SpaceToBatchNDTest(XLATestCase):
outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
[[4, 41], [6, 61]]])
def testDirect(self):
def testDirect0(self):
# Test with zero-size remaining dimension.
self._testDirect(
input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
def testDirect1(self):
# Test with zero-size blocked dimension.
self._testDirect(
input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
def testDirect2(self):
# Test with padding up from zero size.
self._testDirect(
input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])
def testDirect3(self):
self._testDirect(
input_shape=[3, 3, 4, 5, 2],
block_shape=[3, 4, 2],
paddings=[[1, 2], [0, 0], [3, 0]])
def testDirect4(self):
self._testDirect(
input_shape=[3, 3, 4, 5, 2],
block_shape=[3, 4, 2, 2],
paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])
def testDirect5(self):
self._testDirect(
input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
block_shape=[1, 1, 3, 4, 2, 2],
paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])
def testDirect6(self):
self._testDirect(
input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
block_shape=[1, 1, 3, 4, 2, 2, 1],

View File

@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0_bad = gen_data_flow_ops._tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError(
"TensorArray dtype is float but Op requested dtype double."):
"TensorArray dtype is float but op has dtype double."):
r0_bad.eval()
# Test reading from a different index than the one we wrote to
@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase):
[2000.0, -2000.0]],
grad_vals[0])
# TODO(phawkins): implement TensorArrayClose
# def testCloseTensorArray(self):
# with self.test_session() as session, self.test_scope():
# ta = tensor_array_ops.TensorArray(
# dtype=dtypes.float32, tensor_array_name="foo", size=3)
# c1 = ta.close()
# session.run(c1)
def testCloseTensorArray(self):
with self.test_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c1 = ta.close()
session.run(c1)
def testSizeTensorArray(self):
with self.test_session(), self.test_scope():
@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase):
s = ta.size()
self.assertAllEqual(3, s.eval())
# TODO(phawkins): implement TensorArrayClose
# def testWriteCloseTensorArray(self):
# with self.test_session(), self.test_scope():
# ta = tensor_array_ops.TensorArray(
# dtype=dtypes.float32,
# tensor_array_name="foo",
# size=3,
# infer_shape=False)
# w0 = ta.write(0, [[4.0, 5.0]])
# w1 = w0.write(1, [3.0])
# w1.close().run() # Expected to run without problems
def testWriteCloseTensorArray(self):
with self.test_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
size=3,
infer_shape=False)
w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [3.0])
w1.close().run() # Expected to run without problems
# TODO(phawkins): implement while loops.
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):

View File

@ -42,6 +42,7 @@ cc_library(
deps = [
":common",
":dump_graph",
":functionalize_control_flow",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@ -152,7 +153,6 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
@ -165,13 +165,10 @@ cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
@ -203,6 +200,58 @@ cc_library(
],
)
cc_library(
name = "functionalize_control_flow",
srcs = ["functionalize_control_flow.cc"],
hdrs = ["functionalize_control_flow.h"],
deps = [
"//tensorflow/compiler/jit:graph_to_functiondef",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:functional_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
],
)
cc_test(
name = "functionalize_control_flow_test",
srcs = ["functionalize_control_flow_test.cc"],
deps = [
":functionalize_control_flow",
":test_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/cc:functional_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:ops",
"//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.cc"],
hdrs = ["test_util.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -0,0 +1,44 @@
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
tf_gen_op_wrapper_cc(
name = "functional_ops_gen",
include_internal_ops = 1,
out_ops_file = "ops/functional_ops",
deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"],
)
cc_library(
name = "functional_ops",
srcs = ["ops/functional_ops.cc"],
hdrs = ["ops/functional_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/tf2xla/ops:functional_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,566 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include <algorithm>
#include <deque>
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/control_flow.h"
namespace tensorflow {
namespace {
const char* const kArgOp = "_Arg";
const char* const kRetValOp = "_Retval";
// Information about a loop argument.
struct Arg {
// Every loop argument has an Enter node.
Node* enter;
// Is the loop argument a loop-invariant value? Taken from the `is_constant`
// attribute on the Enter node.
bool is_loop_invariant;
// If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
// arguments must have all of the following nodes:
Node* merge = nullptr;
Node* switch_node = nullptr;
Node* next_iteration = nullptr;
Node* exit = nullptr;
};
// Information about a loop frame.
struct Frame {
string name;
// Pointer to the parent frame. The root frame has a pointer to itself.
Frame* parent = nullptr;
int num_children = 0;
// Arguments to this loop.
std::vector<Arg> args;
// The loop condition of the loop. There should be exactly one loop condition
// in every loop.
Node* loop_cond = nullptr;
// Set of nodes that belong to the loop frame.
std::unordered_set<Node*> nodes;
};
// Copies a subgraph from `graph` to `output` by performing a reverse DFS
// starting at nodes in vector `stack`.
// `node_map` is a vector indexed by source node ID to dest nodes.
// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
// before the traversal clients can cut the graph. Returns an error if the
// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
// cut the graph and prevent the traversal from escaping.
//
// `squash_src_outputs` contains a bool for each source node ID. If true, then
// the source output on that node will be replaced by zero when copied. This is
// used when replacing a Switch node with an _Arg node. The output we are
// taking from the Switch node was not necessarily the first output, but _Arg
// nodes only have one output. By adding the Switch node to `squash_src_outputs`
// we rewrite the src_output of the corresponding edge to be 0.
Status CopySubgraph(const Graph& graph, const Frame& frame,
std::vector<Node*> stack,
const std::vector<bool>& squash_src_outputs,
std::vector<Node*>* node_map, Graph* output) {
std::vector<bool> visited(graph.num_node_ids(), false);
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
VLOG(3) << "Copying node " << n->name();
if (visited[n->id()]) continue;
visited[n->id()] = true;
for (const Edge* e : n->in_edges()) {
Node* src = e->src();
if (frame.nodes.find(src) == frame.nodes.end()) {
// We traversed out of the loop frame, without encountering a cut node.
return errors::Internal("Graph traversal of loop frame ", frame.name,
" escaped frame at ", src->name(),
" without encountering an argument node.");
}
if ((*node_map)[src->id()] == nullptr) {
(*node_map)[src->id()] = output->CopyNode(src);
stack.push_back(src);
}
Node* src_copy = (*node_map)[e->src()->id()];
int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output();
Node* dst_copy = (*node_map)[e->dst()->id()];
output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
}
}
return Status::OK();
}
Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) {
NodeDef arg_def;
NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp);
builder.Attr("T", type);
builder.Attr("index", index);
TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
Status status;
*arg_node = graph->AddNode(arg_def, &status);
return status;
}
Status BuildRetvalNode(Graph* graph, DataType type, int index,
Node** retval_node) {
NodeDef ret_def;
ret_def.set_op(kRetValOp);
ret_def.set_name(strings::StrCat("_Retval", index));
AddNodeAttr("T", type, &ret_def);
AddNodeAttr("index", index, &ret_def);
Status status;
*retval_node = graph->AddNode(ret_def, &status);
return status;
}
// Builds a graph for the loop condition.
Status BuildLoopCondition(const Graph& graph, Frame* frame,
std::unique_ptr<Graph>* cond_output) {
VLOG(2) << "Building loop condition for " << frame->name;
*cond_output = xla::MakeUnique<Graph>(graph.op_registry());
Graph* output = cond_output->get();
// Map from nodes in the original graph to the condition graph.
std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
// Build one _Arg node for each Enter node.
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
Node* arg_node;
TF_RETURN_IF_ERROR(
BuildArgNode(output, arg.enter->input_type(0), i, &arg_node));
if (arg.is_loop_invariant) {
node_map[arg.enter->id()] = arg_node;
} else {
node_map[arg.merge->id()] = arg_node;
}
}
// Build a Retval node for the loop condition. The LoopCond nodes are always
// boolean because of the type constraints on the LoopCond op.
TF_RETURN_IF_ERROR(
BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()]));
// Performs a reverse DFS, copying nodes and edges to the output graph.
// The _Arg and _Retval nodes were added unconditionally above, so we are
// guaranteed to get the correct function signature.
TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond},
squash_src_outputs, &node_map, output));
return Status::OK();
}
// Builds a graph for the loop body.
Status BuildLoopBody(const Graph& graph, Frame* frame,
DataTypeVector* arg_types,
std::unique_ptr<Graph>* body_output) {
VLOG(2) << "Building loop body for " << frame->name;
*body_output = xla::MakeUnique<Graph>(graph.op_registry());
Graph* output = body_output->get();
// Map from nodes in the original graph to the condition graph.
std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
// Build one _Arg node for each Enter node.
std::vector<Node*> next_iterations;
next_iterations.reserve(frame->args.size());
arg_types->reserve(frame->args.size());
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
DataType dtype = arg.enter->input_type(0);
arg_types->push_back(dtype);
Node* arg_node;
TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node));
if (dtype == DT_RESOURCE) {
// The convention of the XLA bridge is that resource variable arguments
// are only inputs to the loop body and have no corresponding output.
// TODO(b/37741920): change the convention so that DT_RESOURCE variables
// are both inputs and outputs, and then remove this case.
TF_RET_CHECK(arg.is_loop_invariant);
node_map[arg.enter->id()] = arg_node;
} else {
Node* retval_node;
TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node));
if (arg.is_loop_invariant) {
// Argument is loop-invariant. Forward it from the Arg to the Retval.
node_map[arg.enter->id()] = arg_node;
output->AddEdge(arg_node, 0, retval_node, 0);
} else {
// Argument is loop-varying.
node_map[arg.switch_node->id()] = arg_node;
// The Switch node has two outputs, but _Arg only has one. This tells
// the CopySubgraph function to rewrite the output number of edges from
// the _Arg node to be 0 rather than copying the output number from the
// Switch node.
squash_src_outputs[arg.switch_node->id()] = true;
node_map[arg.next_iteration->id()] = retval_node;
next_iterations.push_back(arg.next_iteration);
}
}
}
// Performs a reverse DFS, copying nodes and edges to the output graph.
// The _Arg and _Retval nodes were added unconditionally above, so we are
// guaranteed to get the correct function signature.
TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations),
squash_src_outputs, &node_map, output));
return Status::OK();
}
Status FunctionalizeLoop(Graph* graph, Frame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< dump_graph::DumpGraphToFile("functionalize_before", *graph);
// Split loop-varying Enter nodes with multiple successors. If the same
// Tensor is fed as input to multiple loop arguments, we may end up with a
// shared Enter node. We clone Enter nodes with multiple successors to
// maintain the invariant of a unique Enter node per argument of the final
// loop.
std::vector<Arg> args;
for (const Arg& arg : frame->args) {
if (arg.is_loop_invariant) {
args.push_back(arg);
} else {
std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
arg.enter->out_edges().end());
for (int i = 0; i < edges.size(); ++i) {
TF_RET_CHECK(!edges[i]->IsControlEdge());
Arg new_arg;
new_arg.is_loop_invariant = false;
if (i == 0) {
new_arg.enter = arg.enter;
} else {
new_arg.enter = graph->CopyNode(arg.enter);
frame->nodes.insert(new_arg.enter);
for (Edge const* e : arg.enter->in_edges()) {
graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
e->IsControlEdge() ? Graph::kControlSlot : 0);
}
Node* dst = edges[i]->dst();
int dst_input = edges[i]->dst_input();
graph->RemoveEdge(edges[i]);
graph->AddEdge(new_arg.enter, 0, dst, dst_input);
}
args.push_back(new_arg);
}
}
}
frame->args = std::move(args);
// Order the arguments so that:
// a) resource variables are last, and
// b) sort lexicographically by name (for deterministic output).
std::sort(frame->args.begin(), frame->args.end(),
[](const Arg& a, const Arg& b) {
bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
return std::tie(a_is_resource, a.enter->name()) <
std::tie(b_is_resource, b.enter->name());
});
if (frame->loop_cond == nullptr) {
return errors::InvalidArgument("Loop ", frame->name,
" has no LoopCond node");
}
// Find the set of Switch nodes that are successors of the LoopCond.
std::unordered_set<Node*> switches;
for (const Edge* edge : frame->loop_cond->out_edges()) {
if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
edge->dst_input() == 1) {
switches.insert(edge->dst());
}
}
// For each non-constant argument, looks for the following pattern of nodes:
// Enter ----> Merge --------> Switch --> Exit
// ^ ^
// | |
// NextIteration LoopCond
// ^ ^
// | |
// ... ...
for (Arg& arg : frame->args) {
if (!arg.is_loop_invariant) {
// Follow the edge from the Enter to Merge.
if (arg.enter->out_edges().size() != 1) {
return errors::Internal("Enter node for loop-varying argument ",
arg.enter->name(),
" does not have exactly one successor");
}
const Edge* enter_merge = *arg.enter->out_edges().begin();
arg.merge = enter_merge->dst();
if (!IsMerge(arg.merge)) {
return errors::InvalidArgument(
"Successor of Enter node for loop-varying argument ",
arg.merge->name(),
" is not a Merge node; got: ", arg.merge->type_string());
}
// Find the NextIteration from the merge. There should be two inputs to
// the Merge and the NextIteration should be the other input.
if (arg.merge->input_types().size() != 2) {
return errors::InvalidArgument(
"Unexpected number of inputs to Merge node for loop-varying "
"argument ",
arg.merge->name(), "; expected 2, got ",
arg.merge->input_types().size());
}
TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
&arg.next_iteration));
if (!IsNextIteration(arg.next_iteration)) {
return errors::InvalidArgument(
"Expected NextIteration node as input to Merge node; got node ",
arg.next_iteration->name(), " with kind ",
arg.next_iteration->type_string());
}
// Find the Switch successor of the Merge. There should be exactly one
// Switch node that is a successor of both the Merge and the LoopCond.
for (const Edge* edge : arg.merge->out_edges()) {
if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
switches.find(edge->dst()) != switches.end()) {
if (arg.switch_node != nullptr) {
return errors::InvalidArgument("Duplicate Switch successors to ",
arg.merge->name());
}
arg.switch_node = edge->dst();
}
}
if (arg.switch_node == nullptr) {
return errors::InvalidArgument("Missing Switch successor to ",
arg.merge->name());
}
// Find the Exit successor of the Switch.
for (const Edge* edge : arg.switch_node->out_edges()) {
if (edge->src_output() == 0 && IsExit(edge->dst())) {
if (arg.exit != nullptr) {
return errors::InvalidArgument("Duplicate Exit successors to ",
arg.switch_node->name());
}
arg.exit = edge->dst();
}
}
if (arg.exit == nullptr) {
return errors::InvalidArgument("Mising Exit successor to ",
arg.switch_node->name());
}
}
}
// Builds the condition and body functions.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
<< " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
static std::atomic<int64> sequence_num(0LL);
int64 id = ++sequence_num;
NameAttrList cond_name;
cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
NameAttrList body_name;
body_name.set_name(strings::StrCat("_functionalize_body_", id));
FunctionDef cond_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
FunctionDef body_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
// Builds a While operator.
NodeDef while_def;
NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
std::vector<NodeDefBuilder::NodeOut> inputs;
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
builder.ControlInput(in_edge->src()->name());
} else {
inputs.push_back(NodeDefBuilder::NodeOut(
in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
}
}
builder.Input(inputs);
TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
Status status;
Node* while_node = graph->AddNode(while_def, &status);
if (!status.ok()) {
return status;
}
// Copies edges to the Enter nodes and from the Exit nodes onto the While.
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
graph->AddControlEdge(in_edge->src(), while_node);
} else {
graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
}
if (!arg.is_loop_invariant) {
std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
arg.exit->out_edges().end());
for (const Edge* edge : edges) {
Node* dst = edge->dst();
int dst_input = edge->dst_input();
graph->RemoveEdge(edge);
int src_output =
dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
graph->AddEdge(while_node, src_output, dst, dst_input);
}
}
}
// Remove the old nodes from the graph, and add the while node to the parent
// frame.
for (Node* node : frame->nodes) {
graph->RemoveNode(node);
}
frame->parent->nodes.insert(while_node);
VLOG(2) << "Frame " << frame->name << " after: "
<< dump_graph::DumpGraphToFile("functionalize_after", *graph);
return Status::OK();
}
} // namespace
// Transformation that converts Tensorflow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow: "
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph);
// Note: BuildControlFlowInfo() requires that the graph's source node is
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
std::vector<ControlFlowInfo> cf_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
// Builds Frames, indexed by name.
std::unordered_map<string, Frame> frames;
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
<< " parent_frame: "
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
Frame& frame = frames[cf.frame_name];
Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
if (frame.parent == nullptr) {
frame.parent = parent;
frame.name = cf.frame_name;
++parent->num_children;
} else if (frame.parent != parent) {
return errors::InvalidArgument("Mismatched parent frames for ",
cf.frame->id(), ": ", parent->name, " vs ",
frame.parent->name);
}
if (IsEnter(node)) {
Arg arg;
arg.enter = node;
TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
&arg.is_loop_invariant));
frame.args.push_back(arg);
} else if (IsLoopCond(node)) {
if (frame.loop_cond) {
return errors::InvalidArgument(
"Loop ", cf.frame_name,
" has more than one LoopCond node: ", node->name(), " and ",
frame.loop_cond->name());
}
frame.loop_cond = node;
}
frame.nodes.insert(node);
}
// Adds frames with no children (i.e., the innermost frames) to a worklist.
std::deque<Frame*> worklist;
for (auto& frame : frames) {
if (frame.second.num_children == 0) {
worklist.push_back(&frame.second);
}
}
// Eliminate loops from innermost to outermost.
while (!worklist.empty()) {
Frame* frame = worklist.front();
worklist.pop_front();
if (frame->parent == frame) {
// Skip the root frame.
continue;
}
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
if (frame->parent->num_children == 0) {
worklist.push_back(frame->parent);
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators, suitable for XLA compilation.
// TODO(b/36470387): add support for conditionals.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_

View File

@ -0,0 +1,647 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
// Returns the names of the "cond" and "body" functions for the While node
// in a graph.
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
NameAttrList* body) {
for (const NodeDef& node : graph.node()) {
if (node.op() == "XlaWhile") {
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
*cond = *result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
*body = *result;
return Status::OK();
}
}
return errors::NotFound("No XlaWhile node found in graph");
}
// Graph:
// x = array_ops.placeholder(dtypes.int32)
// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow, OneLoopVar) {
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto enter =
ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
auto merge = ops::Merge(scope.WithOpName("while/Merge"),
std::initializer_list<Input>{enter, dummy});
auto ten = ops::Const<int32>(
scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
10);
auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
auto switch_ =
ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
switch_.output_false);
auto identity =
ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
auto one = ops::Const<int32>(
scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
auto next_iteration =
ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
auto sink = ops::Identity(scope.WithOpName("sink"), exit);
// Remove the dummy node and add the loop backedge.
scope.graph()->RemoveNode(dummy.node());
scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
TF_EXPECT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
NameAttrList cond_fn, body_fn;
TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
// Outer graph
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
ops::XlaWhile(scope.WithOpName("while/LoopCond"),
std::initializer_list<Input>{source}, cond_fn, body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Condition graph
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto ten = ops::Const<int32>(
scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
auto one = ops::Const<int32>(
scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}
// Graph:
// x = array_ops.placeholder(dtypes.int32)
// y = array_ops.placeholder(dtypes.int32)
// cond = lambda (i, j): i + 3 < 10
// body = lambda (i, j): (i < 10, j * 2)
// z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow, TwoLoopVars) {
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
auto enter_x =
ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
auto enter_y =
ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
std::initializer_list<Input>{enter_x, dummy});
auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
std::initializer_list<Input>{enter_y, dummy});
// Loop condition
auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
.WithControlDependencies(merge_x.output),
3);
auto cond_add =
ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
.WithControlDependencies(merge_x.output),
10);
auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
merge_x.output, loop_cond);
auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
merge_y.output, loop_cond);
auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
switch_x.output_false);
auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
switch_y.output_false);
auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
switch_x.output_true);
auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
switch_y.output_true);
auto one = ops::Const<int32>(
scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
1);
auto two = ops::Const<int32>(
scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
2);
auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
auto next_iteration_x =
ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
auto next_iteration_y =
ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
// Remove the dummy node and add the loop backedges.
scope.graph()->RemoveNode(dummy.node());
scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
1);
scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
1);
TF_EXPECT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
NameAttrList cond_fn, body_fn;
TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
// Outer graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
auto while_op =
ops::XlaWhile(scope.WithOpName("while/LoopCond"),
std::initializer_list<Input>{x, y}, cond_fn, body_fn);
auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
.WithControlDependencies(arg0.output),
3);
auto cond_add =
ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
auto ten = ops::Const<int32>(
scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output),
10);
auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
auto one = ops::Const<int32>(
scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
1);
auto two = ops::Const<int32>(
scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
2);
auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}
// Example with nesting, loop-invariant arguments, and resource variables.
//
// accum = resource_variable_ops.ResourceVariable(1)
// x = array_ops.placeholder(2, dtype=dtypes.int32)
// y = 3 + x
//
// def inner_body(j, k):
// add = state_ops.assign_add(accum, k * j + x)
// with ops.control_dependencies([add]):
// return [j + 1, k]
//
// def body(i):
// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
// [1, y], name="inner")
// with ops.control_dependencies(m):
// return [i + 1]
//
// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow, Complex) {
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
auto y = ops::Add(scope.WithOpName("y"), x, three);
auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
TensorShape({}));
// Outer loop
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
auto enter_i =
ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
std::initializer_list<Input>{enter_i, dummy});
auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
.WithControlDependencies(merge_i.output),
10);
auto less_i =
ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
auto outer_loop_cond =
ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
merge_i.output, outer_loop_cond);
auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
switch_i.output_false);
auto identity_i =
ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
auto enter_x_outer =
ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
ops::internal::Enter::Attrs().IsConstant(true));
auto enter_k_outer =
ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
ops::internal::Enter::Attrs().IsConstant(true));
auto enter_var_outer =
ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
ops::internal::Enter::Attrs().IsConstant(true));
// Inner loop
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
one_j, "inner");
auto enter_k =
ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
.WithControlDependencies(identity_i),
enter_k_outer, "inner");
auto enter_x = ops::internal::Enter(
scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
ops::internal::Enter::Attrs().IsConstant(true));
auto enter_var = ops::internal::Enter(
scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
ops::internal::Enter::Attrs().IsConstant(true));
auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
std::initializer_list<Input>{enter_j, dummy});
auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
std::initializer_list<Input>{enter_k, dummy});
auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
.WithControlDependencies(merge_j.output),
5);
auto less_j =
ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);
auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
merge_j.output, loop_cond);
auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
merge_k.output, loop_cond);
auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
switch_j.output_false);
auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
switch_k.output_false);
auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
switch_j.output_true);
auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
switch_k.output_true);
// Variable update
auto mul_jk =
ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
auto add_jkx =
ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
auto one =
ops::Const<int32>(scope.WithOpName("outer/inner/One")
.WithControlDependencies(
gtl::ArraySlice<Operation>{assign.operation}),
1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
auto next_iteration_j = ops::NextIteration(
scope.WithOpName("outer/inner/NextIteration_j"), add_j);
auto next_iteration_k = ops::NextIteration(
scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
// Body and backedge for outer loop.
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
.WithControlDependencies(gtl::ArraySlice<Operation>{
exit_j.output.op(), exit_k.output.op()}),
identity_i, one_outer);
auto next_iteration_i =
ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
// Remove the dummy node and add the loop backedge.
scope.graph()->RemoveNode(dummy.node());
scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
1);
scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
1);
scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
1);
TF_EXPECT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
NameAttrList outer_cond_fn, outer_body_fn;
TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
// Outer graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
auto y = ops::Add(scope.WithOpName("y"), x, three);
auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
TensorShape({}));
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
std::initializer_list<Input>{zero, y, x, var},
outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Outer condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
auto ten = ops::Const<int32>(
scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
10);
auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Outer body graph.
NameAttrList inner_cond_fn, inner_body_fn;
{
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
// Find the inner condition and body names.
TF_EXPECT_OK(
FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
std::initializer_list<Input>{one_j, arg1, arg2, arg3},
inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
.WithControlDependencies(gtl::ArraySlice<Operation>{
while_op[0].op(), while_op[1].op()}),
identity_i, one_outer);
auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Inner condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
auto five = ops::Const<int32>(
scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Inner body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
auto identity_j =
ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
auto identity_k =
ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
auto mul_jk =
ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
auto one =
ops::Const<int32>(scope.WithOpName("outer/inner/One")
.WithControlDependencies(
gtl::ArraySlice<Operation>{assign.operation}),
1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
auto retval1 =
ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}
} // namespace
} // namespace tensorflow

View File

@ -68,6 +68,7 @@ tf_kernel_library(
"reduction_ops.h",
],
deps = [
":while_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
@ -91,6 +92,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "while_op",
srcs = ["while_op.cc"],
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:functional_ops",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
],
)
# Kernels that only work on CPU, because they use XLA custom calls.
# Only link this when using the CPU backend for XLA.
#

View File

@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel {
XlaContext& xc = XlaContext::Get(ctx);
const XlaContext::Argument& arg = xc.args()[index_];
if (arg.is_variable) {
if (arg.is_resource) {
XlaResource::Kind kind;
switch (arg.kind) {
case XlaCompiler::Argument::kVariable:
kind = XlaResource::kVariable;
break;
case XlaCompiler::Argument::kTensorArray:
kind = XlaResource::kTensorArray;
break;
default:
CHECK(false);
}
// TODO(phawkins): this code assumes that variables do not alias.
XlaVariable* var;
OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
arg.value.handle, &var));
var->tensor_array_size = arg.tensor_array_size;
ctx->SetVariableOutput(0, var);
XlaResource* resource;
OP_REQUIRES_OK(ctx,
xc.CreateResource(kind, index_, arg.name, arg.value.type,
arg.value.handle, &resource));
resource->tensor_array_size = arg.tensor_array_size;
ctx->SetResourceOutput(0, resource);
} else if (arg.value.is_constant) {
ctx->SetConstantOutput(0, arg.value.constant_value);
} else {

View File

@ -127,8 +127,8 @@ void BatchToSpace(XlaOpKernelContext* ctx,
std::vector<int64> end_indices = reshaped_permuted_shape;
std::vector<int64> strides(input_rank, 1);
for (int i = 0; i < block_rank; ++i) {
int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
int64 crop_start = crops.Get<int64>({i, 0});
int64 crop_end = crops.Get<int64>({i, 1});
OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
errors::InvalidArgument("Crops must be non-negative"));
start_indices[1 + i] = crop_start;

View File

@ -55,7 +55,7 @@ class BCastGradArgsOp : public XlaOpKernel {
BCast::Vec vec;
for (int64 i = 0; i < in_shape.num_elements(); ++i) {
vec.push_back(xla::LiteralUtil::Get<int>(literal, {i}));
vec.push_back(literal.Get<int>({i}));
}
shapes.push_back(vec);
}

View File

@ -52,7 +52,7 @@ class ConcatBaseOp : public XlaOpKernel {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
// TODO(annarev): add a helper to support int64 input.
const int32 concat_dim = xla::LiteralUtil::Get<int>(literal, {});
const int32 concat_dim = literal.Get<int>({});
std::vector<xla::ComputationDataHandle> values;
std::vector<TensorShape> shapes;
@ -163,7 +163,7 @@ class ConcatOffsetOp : public XlaOpKernel {
xla::Literal concat_dim_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
const int64 cdim = xla::LiteralUtil::Get<int>(concat_dim_literal, {});
const int64 cdim = concat_dim_literal.Get<int>({});
VLOG(1) << "ConcatOffset " << cdim << "," << dims;
int32 axis = cdim < 0 ? cdim + dims : cdim;
@ -185,12 +185,10 @@ class ConcatOffsetOp : public XlaOpKernel {
for (int64 j = 0; j < dims; ++j) {
if (j == axis) {
out_vec(j) = offset;
offset += xla::LiteralUtil::Get<int>(inp_literal, {j});
offset += inp_literal.Get<int>({j});
} else {
const int32 inp0_element =
xla::LiteralUtil::Get<int>(inp0_literal, {j});
const int32 inp_element =
xla::LiteralUtil::Get<int>(inp_literal, {j});
const int32 inp0_element = inp0_literal.Get<int>({j});
const int32 inp_element = inp_literal.Get<int>({j});
OP_REQUIRES(
ctx, (inp0_element == inp_element),
errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",

View File

@ -103,8 +103,7 @@ class DynamicStitchOp : public XlaOpKernel {
int max_index = -1;
for (int input_num = 0; input_num < indices.size(); input_num++) {
for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
max_index = std::max(
max_index, xla::LiteralUtil::Get<int>(indices[input_num], {i}));
max_index = std::max(max_index, indices[input_num].Get<int>({i}));
}
}
int number_of_indices = max_index + 1;
@ -118,7 +117,7 @@ class DynamicStitchOp : public XlaOpKernel {
int index_used_count = 0;
for (int input_num = 0; input_num < indices.size(); input_num++) {
for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
int index = xla::LiteralUtil::Get<int>(indices[input_num], {i});
int index = indices[input_num].Get<int>({i});
src_input_vector[index] = input_num;
src_slice_vector[index] = i;
if (!src_index_used[index]) {

View File

@ -52,7 +52,7 @@ class FillOp : public XlaOpKernel {
std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
broadcast.push_back(dims_literal.Get<int>({i}));
}
// Look up the value input, reshaping to a scalar if it was a
// 'legacy' scalar (secretly a vector).

View File

@ -66,10 +66,10 @@ class GatherOp : public XlaOpKernel {
std::vector<xla::ComputationDataHandle> args;
args.push_back(tc.GetOrCreateRuntimeContextParameter());
args.push_back(b.ConstantLiteral(
*xla::LiteralUtil::CreateR0<int64>(indices_shape.num_elements())));
*xla::Literal::CreateR0<int64>(indices_shape.num_elements())));
args.push_back(b.ConstantLiteral(
*xla::LiteralUtil::CreateR0<int64>(params_shape.dim_size(0))));
args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int64>(
*xla::Literal::CreateR0<int64>(params_shape.dim_size(0))));
args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int64>(
params_shape.num_elements() / params_shape.dim_size(0))));
args.push_back(ctx->Input(0));
args.push_back(ctx->Input(1));

View File

@ -69,7 +69,7 @@ class ArgMaxOp : public XlaOpKernel {
// XLA op would have the same requirement.
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
const int32 dim = xla::LiteralUtil::Get<int32>(literal, {});
const int32 dim = literal.Get<int32>({});
OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
OP_REQUIRES(
ctx, dim < input_shape.dims(),
@ -97,14 +97,13 @@ class ArgMaxOp : public XlaOpKernel {
std::vector<xla::ComputationDataHandle> args;
args.push_back(ctx->Input(0));
args.push_back(b.ConstantLiteral(
*xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
*xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(b.ConstantLiteral(
*xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int32>(dim)));
*xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =

View File

@ -60,8 +60,8 @@ class PadOp : public XlaOpKernel {
xla::PaddingConfig config;
for (int i = 0; i < fixed_dims; ++i) {
auto* dim = config.add_dimensions();
int before = xla::LiteralUtil::Get<int32>(pad_literal, {i, 0});
int after = xla::LiteralUtil::Get<int32>(pad_literal, {i, 1});
int before = pad_literal.Get<int32>({i, 0});
int after = pad_literal.Get<int32>({i, 1});
OP_REQUIRES(ctx, before >= 0 && after >= 0,
errors::InvalidArgument("Paddings must be non-negative: ",
before, " ", after));

View File

@ -63,7 +63,7 @@ class MinOp : public XlaReductionOp {
xla::ComputationBuilder* builder) override {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
return builder->ConstantLiteral(xla::Literal::MaxValue(type));
}
void BuildReducer(xla::ComputationBuilder* builder,
@ -83,7 +83,7 @@ class MaxOp : public XlaReductionOp {
xla::ComputationBuilder* builder) override {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type));
return builder->ConstantLiteral(xla::Literal::MinValue(type));
}
void BuildReducer(xla::ComputationBuilder* builder,

View File

@ -66,13 +66,13 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
1, {axes_tensor_shape.num_elements()}, &axes_literal));
VLOG(1) << "data shape: " << data_shape.DebugString();
VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal);
VLOG(1) << "axes : " << axes_literal.ToString();
gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
std::vector<int64> xla_axes;
int64 num_elements_reduced = 1LL;
for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
int32 index = xla::LiteralUtil::Get<int>(axes_literal, {i});
int32 index = axes_literal.Get<int>({i});
OP_REQUIRES(ctx,
!(index < -data_shape.dims() || index >= data_shape.dims()),
errors::InvalidArgument("Invalid reduction dimension (", index,

View File

@ -50,7 +50,7 @@ class ReshapeOp : public XlaOpKernel {
int64 product = 1;
int unknown_index = -1;
for (int d = 0; d < num_dims; ++d) {
const int32 size = xla::LiteralUtil::Get<int>(literal, {d});
const int32 size = literal.Get<int>({d});
if (size == -1) {
OP_REQUIRES(
ctx, unknown_index == -1,

View File

@ -32,7 +32,7 @@ template <typename T>
Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
*value = xla::LiteralUtil::Get<T>(literal, {});
*value = literal.Get<T>({});
return Status::OK();
}
@ -41,10 +41,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
switch (literal.shape().element_type()) {
case xla::S32:
*value = xla::LiteralUtil::Get<int32>(literal, {});
*value = literal.Get<int32>({});
break;
case xla::S64:
*value = xla::LiteralUtil::Get<int64>(literal, {});
*value = literal.Get<int64>({});
break;
default:
return errors::InvalidArgument("Invalid argument type for argument",
@ -58,9 +58,9 @@ template <typename T>
Status CreateRangeTensor(const xla::Literal& start_literal,
const xla::Literal& limit_literal,
const xla::Literal& delta_literal, Tensor* output) {
T start = xla::LiteralUtil::Get<T>(start_literal, {});
T limit = xla::LiteralUtil::Get<T>(limit_literal, {});
T delta = xla::LiteralUtil::Get<T>(delta_literal, {});
T start = start_literal.Get<T>({});
T limit = limit_literal.Get<T>({});
T delta = delta_literal.Get<T>({});
if (delta == 0) {
return errors::InvalidArgument("Requires delta != 0: ", delta);

View File

@ -56,8 +56,8 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
padding_config.add_dimensions(); // Don't pad the batch dimension.
for (int i = 0; i < block_rank; ++i) {
auto* dim = padding_config.add_dimensions();
int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
int64 pad_start = paddings.Get<int64>({i, 0});
int64 pad_end = paddings.Get<int64>({i, 1});
OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
errors::InvalidArgument("Paddings must be non-negative"));
dim->set_edge_padding_low(pad_start);

View File

@ -39,7 +39,7 @@ class SplitOp : public XlaOpKernel {
int32 split_dim;
if (index_shape.dims() == 0) {
split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
split_dim = literal_index.Get<int>({});
} else {
OP_REQUIRES(
ctx, index_shape.dims() == 1,
@ -49,7 +49,7 @@ class SplitOp : public XlaOpKernel {
ctx, index_shape.dim_size(0) == 1,
errors::InvalidArgument("split_index input to Split Op must be a "
"scalar or a vector with 1 element"));
split_dim = xla::LiteralUtil::Get<int>(literal_index, {0});
split_dim = literal_index.Get<int>({0});
}
const int32 num_split = num_outputs();
const TensorShape input_shape = ctx->InputShape(1);
@ -115,7 +115,7 @@ class SplitVOp : public XlaOpKernel {
OP_REQUIRES(ctx, index_shape.dims() == 0,
errors::InvalidArgument("split_dim input to Split Op must be a "
"scalar"));
split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
split_dim = literal_index.Get<int>({});
xla::ComputationDataHandle input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
@ -152,7 +152,7 @@ class SplitVOp : public XlaOpKernel {
for (int i = 0; i < num_split; ++i) {
int slice_size;
slice_size = xla::LiteralUtil::Get<int>(split_size_literal, {i});
slice_size = split_size_literal.Get<int>({i});
if (slice_size == -1) {
OP_REQUIRES(
ctx, neg_one_dim == -1,

View File

@ -41,32 +41,36 @@ namespace {
// Since the element shape is not always provided to the TensorArrayV3 operator,
// we must support lazily initialization of the TensorArray at the time of the
// first write.
// If a TensorArray `var` has not been initialized, constructs storage for the
// TensorArray with elements of `elem_shape`. For both initialized and
// If a TensorArray `resource` has not been initialized, constructs storage for
// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
XlaVariable* var, DataType dtype,
XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
if (var->type != dtype) {
if (resource->kind != XlaResource::kTensorArray) {
return errors::InvalidArgument("Unexpected non-TensorArray resource");
}
if (resource->type != dtype) {
return errors::InvalidArgument(
"TensorArray dtype is ", DataTypeString(var->type),
"TensorArray dtype is ", DataTypeString(resource->type),
" but op has dtype ", DataTypeString(dtype), ".");
}
TF_RET_CHECK(var->tensor_array_size >= 0)
<< var->name << " size " << var->tensor_array_size;
TF_RET_CHECK(resource->tensor_array_size >= 0)
<< resource->name << " size " << resource->tensor_array_size;
TensorShape ta_shape;
ta_shape.AddDim(var->tensor_array_size);
ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
if (var->value.handle() == 0) {
if (resource->value.handle() == 0) {
// TensorArray has not been initialized.
xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
} else {
// Checks the elem_shape matches the TensorArray shape.
auto shape_or_status = builder->GetShape(var->value);
auto shape_or_status = builder->GetShape(resource->value);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
return Status::OK();
}
// Checks that the TensorArray 'resource' has been initialized, and has type
// 'dtype'. Sets 'shape' to the shape
Status CheckTensorArrayIsInitialized(const string& op_name,
const XlaResource* resource,
DataType dtype) {
if (resource->kind != XlaResource::kTensorArray) {
return errors::InvalidArgument(
"Unexpected non-TensorArray resource passed "
"to ",
op_name);
}
if (resource->value.handle() == 0) {
return errors::InvalidArgument("Uninitialized TensorArray passed to ",
op_name);
}
if (resource->type != dtype) {
return errors::InvalidArgument(
"TensorArray dtype is ", DataTypeString(resource->type),
" but op has dtype ", DataTypeString(dtype), ".");
}
return Status::OK();
}
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
auto shape_or_status = builder->GetShape(resource->value);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
*shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
if (shape->dims() < 1) {
return errors::InvalidArgument("TensorArray rank must be >= 1");
}
return Status::OK();
}
// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel {
errors::InvalidArgument("TensorArray size must be >= 0"));
xla::ComputationBuilder* b = ctx->builder();
b->set_die_immediately_on_error(true);
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel {
}
XlaContext& xc = XlaContext::Get(ctx);
XlaVariable* var;
XlaResource* var;
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(ctx,
xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
dtype_, value, &var));
var->tensor_array_size = size;
ctx->SetVariableOutput(0, var);
ctx->SetResourceOutput(0, var);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// Initializes the TensorArray, if the element shape was not known at
// construction time.
XlaVariable* var;
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
xla::ComputationDataHandle ta = var->value;
xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
resource->value = written;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
DataType ta_type;
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
OP_REQUIRES(ctx, ta_type == dtype_,
errors::InvalidArgument(
"TensorArray dtype is ", DataTypeString(ta_type),
" but Op requested dtype ", DataTypeString(dtype_), "."));
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
errors::InvalidArgument("TensorArray rank must be >= 1"));
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle ta;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
CheckTensorArrayIsInitialized(name(), resource, dtype_));
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
DataType ta_type;
xla::ComputationBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
CheckTensorArrayIsInitialized(name(), resource, dtype_));
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
OP_REQUIRES(ctx, ta_type == dtype_,
errors::InvalidArgument("TensorArray type mismatch"));
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
errors::InvalidArgument("TensorArray rank must be >= 1"));
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
auto indices = ctx->Input(1);
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle ta;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
xla::ComputationDataHandle ta = resource->value;
// For each index in `indices`, add the corresponding slice to `slices`.
std::vector<xla::ComputationDataHandle> slices(num_indices);
@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel {
const TensorShape value_shape = ctx->InputShape(2);
XlaVariable* var;
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
TensorShape elem_shape = value_shape;
elem_shape.RemoveDim(0);
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
const xla::ComputationDataHandle indices = ctx->Input(1);
xla::ComputationDataHandle ta = var->value;
xla::ComputationDataHandle ta = resource->value;
const xla::ComputationDataHandle value = ctx->Input(2);
auto slice_dims = value_shape.dim_sizes();
@ -357,7 +397,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
resource->value = ta;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@ -376,18 +416,17 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
DataType ta_type;
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
OP_REQUIRES(ctx, ta_type == dtype_,
errors::InvalidArgument("TensorArray type mismatch"));
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
errors::InvalidArgument("TensorArray rank must be >= 1"));
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle ta;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
CheckTensorArrayIsInitialized(name(), resource, dtype_));
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
xla::ComputationDataHandle ta = resource->value;
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@ -438,19 +477,20 @@ class TensorArraySplitOp : public XlaOpKernel {
elem_shape.set_dim(0, length);
xla::ComputationBuilder* b = ctx->builder();
XlaVariable* var;
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
xla::ComputationDataHandle ta = var->value;
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
xla::ComputationDataHandle ta = resource->value;
TensorShape ta_shape;
ta_shape.AddDim(var->tensor_array_size);
ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
errors::InvalidArgument(
"TensorArray's size is not equal to the size of lengths (",
lengths.size(), " vs. ", var->tensor_array_size, ")"));
lengths.size(), " vs. ", resource->tensor_array_size, ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
@ -459,8 +499,7 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@ -478,8 +517,8 @@ class TensorArraySizeOp : public XlaOpKernel {
explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
XlaVariable* var;
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
XlaResource* var;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
Tensor size_tensor(DT_INT32, {});
size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
ctx->SetConstantOutput(0, size_tensor);
@ -500,31 +539,31 @@ class TensorArrayGradOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
XlaVariable* var;
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
DataType ta_type;
OP_REQUIRES_OK(
ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type));
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
errors::InvalidArgument("TensorArray rank must be >= 1"));
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
// Finds or looks up the corresponding gradient TensorArray, which stores
// gradients computed during backpropagation.
XlaVariable*& gradient = var->tensor_array_gradient[source_];
XlaResource*& gradient = resource->tensor_array_gradient[source_];
if (!gradient) {
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
xla::ComputationDataHandle value =
b->Broadcast(zero, ta_shape.dim_sizes());
XlaContext& xc = XlaContext::Get(ctx);
string name = strings::StrCat("TensorArrayGrad: ", var->name);
OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
value, &gradient));
gradient->tensor_array_size = var->tensor_array_size;
string name = strings::StrCat("TensorArrayGrad: ", resource->name);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
resource->type, value, &gradient));
gradient->tensor_array_size = resource->tensor_array_size;
}
ctx->SetVariableOutput(0, gradient);
ctx->SetResourceOutput(0, gradient);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
@ -536,5 +575,19 @@ class TensorArrayGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
class TensorArrayCloseOp : public XlaOpKernel {
public:
explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
// Do nothing; XLA handles resource management.
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp);
};
REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -68,7 +68,7 @@ class TileOp : public XlaOpKernel {
bool all_multiples_are_one = true;
bool one_dimension_is_broadcasted_without_multiple = true;
for (int i = 0; i < input_dims; ++i) {
int multiple = xla::LiteralUtil::Get<int>(literal, {i});
int multiple = literal.Get<int>({i});
OP_REQUIRES(ctx, multiple,
errors::InvalidArgument("Expected multiples[", i,
"] >= 0, but got ", multiple));

View File

@ -44,6 +44,7 @@ namespace {
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
XLAJIT_MAKE_UNARY(Cos, b->Cos(x));
XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.

View File

@ -0,0 +1,265 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
// Builds XlaCompiler argument descriptions `args` from `ctx`.
Status MakeXlaCompilerArgumentsFromInputs(
XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
bool* has_uninitialized_vars) {
VLOG(2) << "Num inputs " << ctx->num_inputs();
args->resize(ctx->num_inputs());
*has_uninitialized_vars = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
VLOG(2) << " Input " << i
<< " type: " << DataTypeString(ctx->input_type(i))
<< " shape: " << ctx->InputShape(i).DebugString();
XlaCompiler::Argument& arg = (*args)[i];
DataType type = ctx->input_type(i);
// When reading a resource input, use the type and shape of the resource's
// current value.
if (type == DT_RESOURCE) {
XlaResource* resource;
TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
arg.initialized = resource->value.handle() > 0;
switch (resource->kind) {
case XlaResource::kVariable:
arg.kind = XlaCompiler::Argument::kVariable;
break;
case XlaResource::kTensorArray:
arg.kind = XlaCompiler::Argument::kTensorArray;
break;
case XlaResource::kInvalid:
CHECK(false);
}
arg.type = resource->type;
if (arg.initialized) {
auto shape = ctx->builder()->GetShape(resource->value);
TF_RETURN_IF_ERROR(shape.status());
arg.shape = XLAShapeToTensorShape(*shape.ValueOrDie());
} else {
*has_uninitialized_vars = true;
}
arg.tensor_array_size = resource->tensor_array_size;
arg.name = resource->name;
// TODO(phawkins): propagate TensorArray gradients into loops.
VLOG(2) << " resource " << resource->name
<< " type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString()
<< " initialized: " << arg.initialized;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = ctx->input_type(i);
arg.shape = ctx->InputShape(i);
}
}
return Status::OK();
}
} // anonymous namespace
XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
const NameAttrList* name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "WhileOp::Compile";
std::vector<XlaCompiler::Argument> arguments;
bool has_uninitialized_vars;
OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
ctx, &arguments, &has_uninitialized_vars));
const bool use_tuple_arg = (arguments.size() != 1);
xla::ComputationBuilder* builder = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
VLOG(1) << "Compiling body";
// All resource that are inputs to the loop's body must also be
// present as loop body outputs; the signature of the loop's input and
// output must match. We ensure this by asking the compiler to include the
// current values of all resources, even if they haven't been updated by the
// computation.
// TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
// operator.
XlaCompiler::CompileOptions body_options;
body_options.use_tuple_arg = use_tuple_arg;
body_options.return_updated_values_for_all_resources = true;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
// We must use a static shape for parameters to an XLA compilation. However,
// we may not know the shape of a TensorArray if it is first written inside
// the loop. Ideally we would require the user to provide a static shape,
// but this is not always easy.
// So if uninitialized resource are used by the loop body, we compile the
// body function twice:
// 1) once with uninitialized resource inputs. We discard the computation
// but we assume resource shapes reach a fixpoint after one iteration.
// So we can use the output shapes of the resource as the "true" shapes.
// 2) again with the "correct" input shapes determined by (1).
if (has_uninitialized_vars) {
// Initializes any uninitialized resource with zero values of the
// shape determined by the first compilation.
for (int i = 0; i < body.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
XlaCompiler::Argument& arg = arguments[update.input_index];
if (!arg.initialized) {
arg.initialized = true;
arg.shape = update.shape;
XlaResource* resource;
OP_REQUIRES_OK(ctx,
ctx->GetResourceInput(update.input_index, &resource));
xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type);
resource->value = builder->Broadcast(zero, update.shape.dim_sizes());
}
}
// Recompile the body with the "correct" shapes.
body = {};
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
}
VLOG(1) << "Compiling condition";
XlaCompiler::CompileOptions cond_options;
cond_options.use_tuple_arg = use_tuple_arg;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
xla::Shape body_input_shape, cond_input_shape;
if (use_tuple_arg) {
body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
} else {
CHECK(!body.xla_input_shapes.empty());
body_input_shape = body.xla_input_shapes[0];
CHECK(!body.xla_input_shapes.empty());
cond_input_shape = cond.xla_input_shapes[0];
}
VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
<< " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
<< " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);
OP_REQUIRES(ctx,
xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
errors::InvalidArgument(
"Input shapes of loop body and condition do not match: ",
xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
xla::ShapeUtil::HumanString(cond_input_shape)));
OP_REQUIRES(
ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape),
errors::InvalidArgument(
"Input and output shapes of loop body do not match: ",
xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
xla::ShapeUtil::HumanString(body.xla_output_shape)));
xla::ComputationDataHandle data;
int num_inputs = body.input_mapping.size();
std::vector<xla::ComputationDataHandle> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
inputs[i] = resource->value;
} else {
inputs[i] = ctx->Input(i);
}
}
xla::ComputationDataHandle init;
if (use_tuple_arg) {
init = builder->Tuple(inputs);
} else {
init = inputs[0];
}
VLOG(1) << "Building while loop";
xla::ComputationDataHandle while_result =
builder->While(*cond.computation, *body.computation, init);
auto get_loop_output = [&](int i) {
if (use_tuple_arg) {
return builder->GetTupleElement(while_result, i);
} else {
return while_result;
}
};
// Sets non-variable outputs.
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
ctx->SetOutput(body.input_mapping[i], get_loop_output(i));
}
}
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
if (update.modified) {
int pos = body.outputs.size() + i;
resource->value = get_loop_output(pos);
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
<< " shape: " << update.shape.DebugString();
// Copies the identity of the resource variable from input to output
// unchanged, even if the variable was not modified.
ctx->op_kernel_context()->set_output(
update.input_index,
ctx->op_kernel_context()->input(update.input_index));
}
VLOG(1) << "Done building while loop";
}
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
} // namespace tensorflow

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/attr_value.pb.h"
namespace tensorflow {
// This TensorFlow op provides a functional iteration primitive.
//
// The inputs and outputs of the loop body must agree on the number, types, and
// shapes of the Tensors carried around the loop body.
//
// Computations in while loops may read from and write to resource variables.
// Resource variables may be passed as arguments to a function's body and
// condition functions. The XlaCompiler converts resource variable arguments
// into parameters to the XLA computation and moves them to the end of the
// parameter list, and by using the `return_updated_values_for_all_variables`
// we ensure that all variables that appear in the input also appear at the
// end of the body's output. This ensures the loop body's input and output
// signatures match.
//
// It is the user's responsibility to ensure that each non-variable _Arg matches
// the corresponding _Retval.
//
// For example, suppose we have a loop body with arguments:
// DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT
// and return values
// DT_INT32, DT_FLOAT
// It is an error for the body to return DT_RESOURCE values.
//
// The body will be lowered into an XLA computation that takes and returns a
// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at
// the end of both the loop body's input and output argument lists.
class XlaWhileOp : public XlaOpKernel {
public:
explicit XlaWhileOp(OpKernelConstruction* ctx);
void Compile(XlaOpKernelContext* ctx) override;
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_

View File

@ -27,13 +27,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape()));
xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal);
literal->Reserve(host_tensor.NumElements());
// memcpy over the payload ...
// TODO(phawkins): handle string types.
size_t total_bytes = host_tensor.TotalBytes();
if (total_bytes > 0) {
void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal);
void* dst_ptr = literal->MutableInternalData();
const void* src_ptr = DMAHelper::base(&host_tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
@ -55,7 +55,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
*host_tensor = Tensor(target_type, shape);
size_t total_bytes = host_tensor->TotalBytes();
if (total_bytes > 0) {
const void* src_ptr = xla::LiteralUtil::InternalData(literal);
const void* src_ptr = literal.InternalData();
void* dst_ptr = DMAHelper::base(host_tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}

View File

@ -27,7 +27,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
{
std::vector<int64> int64_values = {1, 2, 3};
std::unique_ptr<xla::Literal> int64_values_literal =
xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@ -48,7 +48,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
std::unique_ptr<xla::Literal> int32_values_literal =
xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
.ok());

View File

@ -0,0 +1,38 @@
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
cc_library(
name = "functional_ops",
srcs = ["functional_ops.cc"],
deps = [
"//tensorflow/core:framework",
],
alwayslink = 1,
)
tf_gen_op_wrapper_py(
name = "gen_functional_ops",
out = "gen_functional_ops.py",
deps = [
":functional_ops",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
} // namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result) {
const FunctionDef* fdef = library.Find(name);
TF_RET_CHECK(fdef != nullptr);
auto get_func_sig = [&library](const string& op, const OpDef** sig) {
return library.LookUpOpDef(op, sig);
};
InstantiationResult inst;
TF_RETURN_IF_ERROR(
InstantiateFunction(*fdef, AttrSlice(), get_func_sig, &inst));
result->arg_types = inst.arg_types;
result->ret_types = inst.ret_types;
for (NodeDef& n : inst.nodes) {
*result->gdef.add_node() = std::move(n);
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Helper functions for tests.
#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
#include <map>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Same as InstantiationResult, but has a GraphDef instead of just nodes.
struct InstantiationResultForTest {
DataTypeVector arg_types;
DataTypeVector ret_types;
GraphDef gdef;
};
// Instantiates a function, producing a GraphDef to compare against the
// expected graph.
Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_

View File

@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};
struct XlaVariable {
// If this variable is visible externally, what was its argument number?
// Represents a resource, such as a Variable or TensorArray.
struct XlaResource {
enum Kind {
kInvalid,
kVariable,
kTensorArray,
};
Kind kind = kInvalid;
// If this resource is visible externally, what was its argument number?
int arg_num = -1;
// A descriptive name for the variable, used in error messages.
// A descriptive name for the resource, used in error messages.
string name;
// Current type and value of the variable. Uninitialized variables are
// Current type and value of the resource. Uninitialized resources are
// represented by a default (zero) handle and type DT_INVALID.
// While the type of a variable is notionally fixed during execution, when
// a variable is first initialized we do not yet know its type, so we keep
// While the type of a resource is notionally fixed during execution, when
// a resource is first initialized we do not yet know its type, so we keep
// track of its type dynamically.
DataType type = DT_INVALID;
xla::ComputationDataHandle value;
// Value of the variable at computation entry. Used to detect which
// Value of the resource at computation entry. Used to detect which
// variables have new values that need to be written back.
xla::ComputationDataHandle initial_value;
// We treat TensorArrays as a Variable with some extra metadata.
// TensorArray-specific fields
// 'tensor_array_size' stores the expected size of the TensorArray. We need
// to store this since sometimes TensorArrays must be initialized lazily since
@ -91,10 +100,10 @@ struct XlaVariable {
int64 tensor_array_size = -1;
// 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
// to an XlaVariable containing the gradient TensorArrays. We store a pointer
// to an XlaResource containing the gradient TensorArrays. We store a pointer
// here since there should only be one gradient TensorArray per 'source'
// string, irrespective of the number of calls to TensorArrayGrad.
std::unordered_map<string, XlaVariable*> tensor_array_gradient;
std::unordered_map<string, XlaResource*> tensor_array_gradient;
};
// A XlaExpression wraps an XLA computation. Each Tensor on an
@ -115,8 +124,8 @@ class XlaExpression {
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
void set_variable(XlaVariable* variable) { variable_ = variable; }
XlaVariable* variable() const { return variable_; }
void set_resource(XlaResource* resource) { resource_ = resource; }
XlaResource* resource() const { return resource_; }
private:
// The XLA handle of the expression's computation.
@ -128,7 +137,7 @@ class XlaExpression {
bool has_constant_value_ = false;
Tensor constant_value_;
XlaVariable* variable_ = nullptr; // Not owned.
XlaResource* resource_ = nullptr; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -85,9 +86,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
(*options_.populate_resource_manager)(device_->resource_manager());
}
flib_def_.reset(new FunctionLibraryDefinition(*options.flib_def));
flib_runtime_.reset(NewFunctionLibraryRuntime(
&device_mgr_, Env::Default(), device_, options.graph_def_version,
options.flib_def, OptimizerOptions(),
flib_def_.get(), OptimizerOptions(),
nullptr /* custom_kernel_creator */));
}
@ -249,35 +251,36 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
std::vector<xla::Shape>* input_shapes) {
context_args->resize(args.size());
// Argument numbers of arguments and variables that are to be passed to the
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
std::vector<int> parameters, variables;
std::vector<int> parameters, resources;
parameters.reserve(args.size());
variables.reserve(args.size());
resources.reserve(args.size());
for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
++i) {
XlaContext::Argument& context_arg = (*context_args)[i];
context_arg.kind = args[i].kind;
context_arg.name = args[i].name;
context_arg.value.constant_value = args[i].constant_value;
context_arg.value.type = args[i].type;
switch (args[i].kind) {
case XlaCompiler::Argument::kVariable:
variables.push_back(i);
context_arg.is_variable = true;
context_arg.value.is_constant = false;
case XlaCompiler::Argument::kTensorArray:
context_arg.is_resource = true;
if (args[i].initialized) {
resources.push_back(i);
context_arg.value.is_constant = false;
} else {
context_arg.value.is_constant = true;
}
context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kParameter:
parameters.push_back(i);
context_arg.value.is_constant = false;
break;
case XlaCompiler::Argument::kUninitializedVariable:
context_arg.is_variable = true;
context_arg.value.is_constant = true;
context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kConstant:
context_arg.value.is_constant = true;
break;
@ -288,7 +291,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// Append parameters containing variable values after the other runtime
// parameters.
parameters.insert(parameters.end(), variables.begin(), variables.end());
parameters.insert(parameters.end(), resources.begin(), resources.end());
if (parameters.empty()) {
return Status::OK();
}
@ -329,22 +332,22 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// variable states, generated by the symbolic evaluation.
// If `has_side_effects` is true, the computation has side effects and should be
// built even if it has no outputs.
// If `return_updated_values_for_all_variables` is true, all variables will be
// included in `variable_updates`, regardless of whether their value changed.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*variable_updates` to a description of variables whose values are
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `variable_updates.size()` return values from the computation. Each entry in
// `variable_updates` is a (input_index, type) pair, where `input_index` is the
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
const std::vector<std::unique_ptr<XlaVariable>>& variables,
bool has_side_effects, bool return_updated_values_for_all_variables,
const std::vector<std::unique_ptr<XlaResource>>& resources,
bool has_side_effects, bool return_updated_values_for_all_resources,
xla::ComputationBuilder* builder, xla::Computation* computation,
int* num_nonconst_outputs,
std::vector<XlaCompiler::VariableUpdate>* variable_updates) {
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retvals.size());
for (const XlaContext::HandleOrConstant& retval : retvals) {
@ -354,24 +357,24 @@ Status BuildComputation(
}
*num_nonconst_outputs = elems.size();
// Add return values for variables whose values have changed.
std::vector<const XlaVariable*> arg_vars;
arg_vars.reserve(variables.size());
for (const auto& var : variables) {
// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_vars;
arg_vars.reserve(resources.size());
for (const auto& var : resources) {
if (var->arg_num >= 0) {
arg_vars.push_back(var.get());
}
}
std::sort(arg_vars.begin(), arg_vars.end(),
[](const XlaVariable* a, const XlaVariable* b) {
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num < b->arg_num;
});
for (const XlaVariable* var : arg_vars) {
for (const XlaResource* var : arg_vars) {
bool modified = var->value.handle() != var->initial_value.handle();
if (return_updated_values_for_all_variables || modified) {
variable_updates->emplace_back();
XlaCompiler::VariableUpdate& update = variable_updates->back();
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = var->arg_num;
update.type = var->type;
update.modified = modified;
@ -413,6 +416,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
// Converts Tensorflow's graph control-flow constructs into functional
// control-flow that can be compiled into XLA code.
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph.get(), flib_def_.get()));
xla::ComputationBuilder builder(client(), name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
@ -433,10 +440,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs;
result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
context->retvals(), context->variables(), context->has_side_effects(),
options.return_updated_values_for_all_variables, &builder,
context->retvals(), context->resources(), context->has_side_effects(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_nonconst_outputs,
&result->variable_updates));
&result->resource_updates));
result->requires_runtime_context = context->has_context_parameter();
@ -511,15 +518,15 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
}
}
for (std::vector<VariableUpdate>::size_type i = 0;
i < result->variable_updates.size(); ++i) {
for (std::vector<ResourceUpdate>::size_type i = 0;
i < result->resource_updates.size(); ++i) {
if (num_computation_outputs > 1) {
result->variable_updates[i].shape =
result->resource_updates[i].shape =
XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
result->xla_output_shape, computation_output));
} else {
CHECK_EQ(0, computation_output);
result->variable_updates[i].shape =
result->resource_updates[i].shape =
XLAShapeToTensorShape(result->xla_output_shape);
}
++computation_output;

View File

@ -85,14 +85,14 @@ class XlaCompiler {
// Argument is a compile-time constant. No associated runtime parameter.
kConstant,
// Argument is a variable that has not been initialized yet. No associated
// runtime parameter.
kUninitializedVariable,
// Argument is a variable that already has a value set. Expects a runtime
// parameter containing the current value.
// Argument is a variable resource. Has an associated runtime parameter
// iff `initialized` is true.
kVariable,
// Argument is a TensorArray resource. Has an associated runtime parameter
// iff `initialized` is true.
kTensorArray,
// Argument is a run-time parameter.
kParameter,
};
@ -114,8 +114,11 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
// For a kVariable or kUninitializedVariable corresponding to a TensorArray,
// what is the tensor array's declared size?
// For a kVariable or kTensorArray, has this resource been initialized?
bool initialized = false;
// For a kTensorArray, what is the array's declared size? (Used for lazy
// initialization.)
int64 tensor_array_size = -1;
bool operator==(const Argument& other) const;
@ -133,7 +136,7 @@ class XlaCompiler {
};
// Describes a variable write side effect of the computation.
struct VariableUpdate {
struct ResourceUpdate {
// Index of the input that contains the variable resource to write to.
int input_index;
@ -142,14 +145,14 @@ class XlaCompiler {
TensorShape shape;
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_variables` is true.)
// (Always true, unless `return_updated_values_for_all_resources` is true.)
bool modified;
};
struct CompilationResult {
// Vector that maps from the parameters of the XLA computation to their
// original argument positions. To handle compile-time constant inputs and
// variables, the parameters to the XLA computation may be a subset of the
// resources, the parameters to the XLA computation may be a subset of the
// original arguments, and are not necessarily in the same order.)
std::vector<int> input_mapping;
@ -172,10 +175,10 @@ class XlaCompiler {
// containing both constant and non-constant results.
std::vector<OutputDescription> outputs;
// Variables whose values were updated by the computation, ordered
// by return value position. Variable updates follow the non-constant
// Resources whose values were updated by the computation, ordered
// by return value position. Resource updates follow the non-constant
// results in the outputs of XLA computation.
std::vector<VariableUpdate> variable_updates;
std::vector<ResourceUpdate> resource_updates;
// The XLA computation built from the tensorflow subgraph. May be null
// if the output consists solely of compile-time constants.
@ -229,12 +232,12 @@ class XlaCompiler {
// arguments; if false, each argument gets its own parameter.
bool use_tuple_arg = false;
// If 'return_updated_values_for_all_variables' is true, then updated
// values of all resource variables arguments will be included in the
// 'variable_updates' of the computation, even if the variable was not
// If 'return_updated_values_for_all_resources' is true, then updated
// values of all resource resources arguments will be included in the
// 'resource_updates' of the computation, even if the resource was not
// modified by the computation. Used when compiling loop bodies to ensure
// the input and output signatures match.
bool return_updated_values_for_all_variables = false;
bool return_updated_values_for_all_resources = false;
};
// Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
@ -294,6 +297,7 @@ class XlaCompiler {
XlaCompilationDevice* device_; // Owned by device_mgr_
DeviceMgr device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
struct SignatureHash {

View File

@ -163,9 +163,9 @@ TEST_F(XlaCompilerTest, Simple) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@ -179,7 +179,7 @@ TEST_F(XlaCompilerTest, Simple) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::CreateR1<int32>({4, 143});
xla::Literal::CreateR1<int32>({4, 143});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@ -236,7 +236,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::CreateR1<int32>({-7, -42});
xla::Literal::CreateR1<int32>({-7, -42});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
@ -260,7 +260,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@ -270,12 +270,11 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({-7, -42});
xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
}
}

View File

@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() {
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
string name, DataType type,
const xla::ComputationDataHandle& handle,
XlaVariable** variable) {
variables_.emplace_back(new XlaVariable);
*variable = variables_.back().get();
XlaVariable& var = **variable;
var.arg_num = arg_num;
var.name = std::move(name);
var.type = type;
var.initial_value = var.value = handle;
XlaResource** resource) {
resources_.emplace_back(new XlaResource);
*resource = resources_.back().get();
XlaResource& r = **resource;
r.kind = kind;
r.arg_num = arg_num;
r.name = std::move(name);
r.type = type;
r.initial_value = r.value = handle;
return Status::OK();
}

View File

@ -52,11 +52,13 @@ class XlaContext : public ResourceBase {
};
struct Argument {
// Descriptive name for the variable, for use in error messages.
XlaCompiler::Argument::Kind kind;
// Descriptive name for the resource, for use in error messages.
string name;
// Is this a variable?
bool is_variable = false;
// Is this a resource?
bool is_resource = false;
HandleOrConstant value;
@ -106,15 +108,15 @@ class XlaContext : public ResourceBase {
bool has_side_effects() const { return has_side_effects_; }
// Creates a variable with variable `variable_id` and initial type `type` and
// Creates a resource with resource `kind` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Fails if the variable already exists.
Status CreateVariable(int arg_num, string name, DataType type,
const xla::ComputationDataHandle& handle,
XlaVariable** variable);
// Fails if the resource already exists.
Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
DataType type, const xla::ComputationDataHandle& handle,
XlaResource** resource);
const std::vector<std::unique_ptr<XlaVariable>>& variables() {
return variables_;
const std::vector<std::unique_ptr<XlaResource>>& resources() {
return resources_;
}
// Get an XLA lambda to compute Max. This is cached in the
@ -166,8 +168,8 @@ class XlaContext : public ResourceBase {
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;
// Holds ownership of variables. The variables are not ordered.
std::vector<std::unique_ptr<XlaVariable>> variables_;
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;

View File

@ -30,28 +30,28 @@ xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::LiteralUtil::MinValue(type));
return b->ConstantLiteral(xla::Literal::MinValue(type));
}
xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
return b->ConstantLiteral(xla::Literal::MaxValue(type));
}
xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::LiteralUtil::Zero(type));
return b->ConstantLiteral(xla::Literal::Zero(type));
}
xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::LiteralUtil::One(type));
return b->ConstantLiteral(xla::Literal::One(type));
}
xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
@ -61,28 +61,28 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
switch (type) {
case xla::U8:
literal = *xla::LiteralUtil::CreateR0<uint8>(value);
literal = *xla::Literal::CreateR0<uint8>(value);
break;
case xla::U32:
literal = *xla::LiteralUtil::CreateR0<uint32>(value);
literal = *xla::Literal::CreateR0<uint32>(value);
break;
case xla::U64:
literal = *xla::LiteralUtil::CreateR0<uint64>(value);
literal = *xla::Literal::CreateR0<uint64>(value);
break;
case xla::S8:
literal = *xla::LiteralUtil::CreateR0<int8>(value);
literal = *xla::Literal::CreateR0<int8>(value);
break;
case xla::S32:
literal = *xla::LiteralUtil::CreateR0<int32>(value);
literal = *xla::Literal::CreateR0<int32>(value);
break;
case xla::S64:
literal = *xla::LiteralUtil::CreateR0<int64>(value);
literal = *xla::Literal::CreateR0<int64>(value);
break;
case xla::F32:
literal = *xla::LiteralUtil::CreateR0<float>(value);
literal = *xla::Literal::CreateR0<float>(value);
break;
case xla::F64:
literal = *xla::LiteralUtil::CreateR0<double>(value);
literal = *xla::Literal::CreateR0<double>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@ -91,7 +91,7 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::F16:
literal =
*xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";

View File

@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK(expression->handle().handle() != 0 ||
expression->variable() != nullptr);
expression->resource() != nullptr);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
@ -144,9 +144,9 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
return errors::InvalidArgument("value is not a scalar");
}
if (literal.shape().element_type() == xla::S32) {
*out = xla::LiteralUtil::Get<int32>(literal, {});
*out = literal.Get<int32>({});
} else if (literal.shape().element_type() == xla::S64) {
*out = xla::LiteralUtil::Get<int64>(literal, {});
*out = literal.Get<int64>({});
} else {
return errors::InvalidArgument("value must be either int32 or int64");
}
@ -168,11 +168,11 @@ static Status LiteralToInt64Vector(const xla::Literal& literal,
int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
if (literal.shape().element_type() == xla::S32) {
for (int64 i = 0; i < size; ++i) {
out->push_back(xla::LiteralUtil::Get<int32>(literal, {i}));
out->push_back(literal.Get<int32>({i}));
}
} else if (literal.shape().element_type() == xla::S64) {
for (int64 i = 0; i < size; ++i) {
out->push_back(xla::LiteralUtil::Get<int64>(literal, {i}));
out->push_back(literal.Get<int64>({i}));
}
} else {
return errors::InvalidArgument("value must be either int32 or int64");
@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput(
int index, xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaVariable* variable = expression->variable();
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput(
return Status::OK();
}
string XlaOpKernelContext::VariableDebugString(int index) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaVariable* variable = expression->variable();
if (!variable) {
return "<invalid variable ID>";
}
return variable->name;
}
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaVariable* variable = expression->variable();
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@ -337,33 +329,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
expression->set_constant_value(constant);
}
void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the variable resource
// (i.e., a scalar), not the shape of the variable's value.
// The shape of the output tensor is the shape of the resource itself
// (i.e., a scalar), not the shape of the resource's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_variable(variable);
expression->set_resource(resource);
}
Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(index));
TF_RET_CHECK(expression->variable() != nullptr);
*variable = expression->variable();
TF_RET_CHECK(expression->resource() != nullptr);
*resource = expression->resource();
return Status::OK();
}
Status XlaOpKernelContext::AssignVariable(
int index, DataType type, const xla::ComputationDataHandle& handle) {
int input_index, DataType type, const xla::ComputationDataHandle& handle) {
TF_RET_CHECK(handle.handle() != 0);
SetOpHasSideEffects();
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(index));
XlaVariable* variable = expression->variable();
CastExpressionFromTensor(context_->input(input_index));
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
(variable->type == type))) {
return errors::InvalidArgument(

View File

@ -148,6 +148,12 @@ class XlaOpKernelContext {
// Variables
// Sets '*resource' to the resource associated with input `index`.
Status GetResourceInput(int index, XlaResource** resource);
// Sets output 'index' to be a reference to resource 'resource'.
void SetResourceOutput(int index, XlaResource* resource);
// Sets `*type` and `*shape` to the current type and shape of a variable's
// value.
Status GetVariableTypeAndShape(int index, DataType* type,
@ -158,20 +164,10 @@ class XlaOpKernelContext {
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
// Assigns the value `handle` to the variable referenced by input
// `variable_index`. Marks the operator as having side effects.
Status AssignVariable(int variable_index, DataType type,
// `input_index`. Marks the operator as having side effects.
Status AssignVariable(int input_index, DataType type,
const xla::ComputationDataHandle& handle);
// Sets '*variable' to the variable associated with input `index`.
Status GetVariableInput(int index, XlaVariable** variable);
// Sets output 'index' to be a reference to variable 'variable'. Used
// to propagate resource variables through the compilation.
void SetVariableOutput(int index, XlaVariable* variable);
// Returns a human-readable debug string describing 'variable_index'.
string VariableDebugString(int variable_index);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(Status s);
void CtxFailureWithWarning(Status s);

View File

@ -34,11 +34,18 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
const char* const DEVICE_XLA_CPU = "XLA_CPU";
const char* const DEVICE_XLA_GPU = "XLA_GPU";
// Is platform 'id' supported by XLA?
static bool IsPlatformSupported(perftools::gputools::Platform::Id id) {
auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id);
if (!platform.ok()) return false;
return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok();
static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
NodeDef node_def;
node_def.set_name("_XlaLaunch-op");
node_def.set_op("_XlaLaunch");
string kernel_class_name;
TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
&kernel_class_name));
VLOG(1) << "LaunchOpHasKernelForDevice"
<< " kernel_class_name: " << kernel_class_name;
return Status::OK();
}
XlaOpRegistry::XlaOpRegistry() = default;
@ -75,7 +82,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// GetCompilationDevice is called.
static void* registration_init = [&registry]() {
mutex_lock lock(registry.mutex_);
if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) {
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_CPU];
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
@ -83,7 +90,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
registration.enable_jit_by_default = false;
registration.compile_resource_ops = false;
}
if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) {
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_GPU];
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;

View File

@ -46,21 +46,18 @@ xla_proto_library(
],
)
# This is a headers target that extra XLA devices can use to prevent
# circular dependencies. Devices that are compiled as separate shared
# objects can also use it to prevent linking of library code.
cc_header_only_library(
name = "xla_headers_lib",
visibility = ["//visibility:public"],
cc_library(
name = "execution_options_util",
srcs = [
"execution_options_util.cc",
],
hdrs = [
"execution_options_util.h",
],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:stream_executor_headers_lib",
":xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
],
)
@ -602,3 +599,18 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
cc_header_only_library(
name = "xla_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_data_proto",
":xla_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:stream_executor_headers_lib",
],
)

View File

@ -114,7 +114,6 @@ cc_library(
"//tensorflow/compiler/xla/service:compile_only_service",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@llvm//:support",
],

View File

@ -971,6 +971,11 @@ ComputationDataHandle ComputationBuilder::Sign(
return UnaryOp(UNOP_SIGN, operand);
}
ComputationDataHandle ComputationBuilder::Cos(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_COS, operand);
}
ComputationDataHandle ComputationBuilder::Tanh(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_TANH, operand);
@ -1411,6 +1416,52 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BatchNormTraining(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
BatchNormTrainingRequest request;
*request.mutable_operand() = operand;
*request.mutable_scale() = scale;
*request.mutable_offset() = offset;
request.set_epsilon(epsilon);
request.set_feature_index(feature_index);
OpRequest op_request;
*op_request.mutable_batch_norm_training_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making BatchNormTraining request";
Status s = client_->stub()->Op(&op_request, &response);
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BatchNormInference(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& offset, const ComputationDataHandle& mean,
const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
// TODO(b/62843645): Implement BatchNormInference.
NoteError(Unimplemented("BatchNormInference is not implemented yet."));
return ComputationDataHandle();
}
ComputationDataHandle ComputationBuilder::BatchNormGrad(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& batch_mean,
const ComputationDataHandle& batch_var,
const ComputationDataHandle& grad_output, float epsilon,
int64 feature_index) {
// TODO(b/62843645): Implement BatchNormGrad.
NoteError(Unimplemented("BatchNormGrad is not implemented yet."));
return ComputationDataHandle();
}
ComputationDataHandle ComputationBuilder::CrossReplicaSum(
const ComputationDataHandle& operand) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
@ -1487,6 +1538,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::ReducePrecision(
const ComputationDataHandle& operand, const int exponent_bits,
const int mantissa_bits) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
ReducePrecisionRequest request;
*request.mutable_operand() = operand;
request.set_exponent_bits(exponent_bits);
request.set_mantissa_bits(mantissa_bits);
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_reduce_precision_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making reduce-precision request";
Status s = client_->stub()->Op(&op_request, &response);
return ParseOpResponse(s, &response);
}
void ComputationBuilder::Send(const ComputationDataHandle& operand,
const ChannelHandle& handle) {
if (!first_error_.ok() || !PrepareComputation().ok()) {

View File

@ -510,6 +510,9 @@ class ComputationBuilder {
// Enqueues a sign instruction onto the computation.
ComputationDataHandle Sign(const ComputationDataHandle& operand);
// Enqueues a cosine instruction onto the computation.
ComputationDataHandle Cos(const ComputationDataHandle& operand);
// Enqueues a tanh instruction onto the computation.
ComputationDataHandle Tanh(const ComputationDataHandle& operand);
@ -597,6 +600,11 @@ class ComputationBuilder {
const Computation& body,
const ComputationDataHandle& init);
// Enqueues a ReducePrecision node onto the computation.
ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
const int exponent_bits,
const int mantissa_bits);
// Enqueues a Send node onto the computation, to send the given operand to
// a Recv instruction that shares the same channel handle.
void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
@ -820,87 +828,80 @@ class ComputationBuilder {
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
return ConstantOp(
[value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); });
return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR1(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR1(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
NativeT value) {
return ConstantOp([length, value](Literal* literal) {
LiteralUtil::PopulateWithValue(value, {length}, literal);
literal->PopulateWithValue(value, {length});
});
}
inline ComputationDataHandle ComputationBuilder::ConstantR1(
const tensorflow::core::Bitmap& values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR1(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR1(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR2(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR2(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
LiteralUtil::PopulateR2FromArray2DWithLayout(values, layout, literal);
literal->PopulateR2FromArray2DWithLayout(values, layout);
});
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
const Array2D<NativeT>& values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR2FromArray2D(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
LiteralUtil::PopulateR3FromArray3DWithLayout(values, layout, literal);
literal->PopulateR3FromArray3DWithLayout(values, layout);
});
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
const Array3D<NativeT>& values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR3FromArray3D(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal);
literal->PopulateR4FromArray4DWithLayout(values, layout);
});
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
const Array4D<NativeT>& values) {
return ConstantOp([&values](Literal* literal) {
LiteralUtil::PopulateR4FromArray4D(values, literal);
});
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
}
} // namespace xla

View File

@ -32,6 +32,7 @@ cc_library(
srcs = ["testing.cc"],
hdrs = ["testing.h"],
deps = [
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -34,11 +35,11 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
client,
tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
// TODO(b/26811613): Replace this when RNG is supported on all backends.
b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())),
b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())),
AsInt64Slice(shape.dimensions()));
Computation computation = b.Build().ConsumeValueOrDie();
ExecutionOptions execution_options;
auto execution_options = CreateDefaultExecutionOptions();
*execution_options.mutable_shape_with_output_layout() = shape;
return client->Execute(computation, /*arguments=*/{}, &execution_options)
.ConsumeValueOrDie();

View File

@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
return execution_profile_;
}
ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
DeviceAssignment* device_assignment) {
device_assignment_ = device_assignment;
return *this;
}
DeviceAssignment* ExecutableRunOptions::device_assignment() const {
return device_assignment_;
}
} // namespace xla

View File

@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {
class DeviceMemoryAllocator;
class DeviceAssignment;
class ExecutionProfile;
// Class containing options for running a LocalExecutable.
@ -79,9 +80,14 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile() const;
ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
ExecutableRunOptions& set_device_assignment(
DeviceAssignment* device_assignment);
DeviceAssignment* device_assignment() const;
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
DeviceAssignment* device_assignment_ = nullptr;
perftools::gputools::Stream* stream_ = nullptr;
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;

View File

@ -1,6 +1,4 @@
<!--
@license
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,19 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-imports/d3.html">
namespace xla {
<!--
tf-color-scale is a plumbing component that takes in an array of runs, and produces
an upward-bindable outColorScale, which is a color scale mapping from those runs to
a set of colors.
ExecutionOptions CreateDefaultExecutionOptions() {
ExecutionOptions execution_options;
*(execution_options.mutable_debug_options()) =
legacy_flags::GetDebugOptionsFromFlags();
return execution_options;
}
@element tf-color-scale
-->
<dom-module id="tf-color-scale">
<script src="palettes.js"></script>
<script src="colorScale.js"></script>
</dom-module>
} // namespace xla

View File

@ -0,0 +1,29 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
#include "tensorflow/compiler/xla/xla.pb.h"
namespace xla {
// Create a default ExecutionOptions proto; this proto has its debug options
// popupated to the default values taken from flags.
ExecutionOptions CreateDefaultExecutionOptions();
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_

View File

@ -73,26 +73,12 @@ cc_library(
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "cpu_compiler_flags",
srcs = ["cpu_compiler_flags.cc"],
hdrs = ["cpu_compiler_flags.h"],
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "cpu_runtime_flags",
srcs = ["cpu_runtime_flags.cc"],
@ -128,30 +114,6 @@ cc_library(
],
)
cc_library(
name = "gpu_compiler_flags",
srcs = ["gpu_compiler_flags.cc"],
hdrs = ["gpu_compiler_flags.h"],
deps = [
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "gpu_backend_lib_flags",
srcs = ["gpu_backend_lib_flags.cc"],
hdrs = ["gpu_backend_lib_flags.h"],
deps = [
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "stream_assignment_flags",
srcs = ["stream_assignment_flags.cc"],
@ -175,28 +137,6 @@ cc_library(
],
)
cc_library(
name = "alias_analysis_flags",
srcs = ["alias_analysis_flags.cc"],
hdrs = ["alias_analysis_flags.h"],
deps = [
":parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "llvm_util_flags",
srcs = ["llvm_util_flags.cc"],
hdrs = ["llvm_util_flags.h"],
deps = [
":parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "service_flags",
srcs = ["service_flags.cc"],

View File

@ -1,62 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for XLA's alias_analysis module.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static AliasAnalysisFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new AliasAnalysisFlags;
flags->xla_emit_alias_scope = true;
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope,
"Use buffer analysis to refine alias-analysis."),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's alias_analysis
// module.
void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the AliasAnalysisFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
AliasAnalysisFlags* GetAliasAnalysisFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
// Legacy flags for XLA's alias_analysis module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's alias_analysis
// module.
void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's alias_analysis module.
typedef struct {
bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis.
} AliasAnalysisFlags;
// Return a pointer to the AliasAnalysisFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
AliasAnalysisFlags* GetAliasAnalysisFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_

View File

@ -1,68 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for XLA's cpu_compiler module.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static CpuCompilerFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new CpuCompilerFlags;
flags->xla_cpu_embed_ir = false;
flags->xla_cpu_dump_debug_json_to = "";
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag(
"xla_cpu_embed_ir", &flags->xla_cpu_embed_ir,
"Embed the LLVM IR module string in the resultant CpuExecutable."),
tensorflow::Flag("xla_cpu_dump_debug_json_to",
&flags->xla_cpu_dump_debug_json_to,
"Dump debug JSON to this directory."),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's cpu_compiler
// module.
void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the CpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
CpuCompilerFlags* GetCpuCompilerFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -1,49 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
// Legacy flags for the XLA's cpu_compiler module.
#include <vector>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's cpu_compiler
// module.
void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's cpu_compiler module.
typedef struct {
bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant
// CpuExecutable
string xla_cpu_dump_debug_json_to; // Dump debug JSON to this directory.
} CpuCompilerFlags;
// Return a pointer to the CpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
CpuCompilerFlags* GetCpuCompilerFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_

View File

@ -28,6 +28,12 @@ struct DebugOptionsFlags {
string xla_disable_hlo_passes;
bool xla_enable_fast_math;
int32 xla_backend_optimization_level;
bool xla_embed_ir_in_executable;
string xla_dump_debug_json_to;
string xla_gpu_cuda_data_dir;
bool xla_gpu_ftz;
string xla_backend_extra_options;
};
@ -44,7 +50,11 @@ void AllocateFlags() {
flag_values->xla_generate_hlo_graph = "";
flag_values->xla_disable_hlo_passes = "";
flag_values->xla_enable_fast_math = true;
flag_values->xla_backend_optimization_level = 2;
flag_values->xla_backend_optimization_level = 3;
flag_values->xla_embed_ir_in_executable = false;
flag_values->xla_dump_debug_json_to = "";
flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
flag_values->xla_gpu_ftz = false;
flag_values->xla_backend_extra_options = "";
flag_objects = new std::vector<tensorflow::Flag>(
@ -52,7 +62,6 @@ void AllocateFlags() {
"xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
"HLO modules matching this regex will be dumped to a .dot file "
"throughout various stages in compilation."),
tensorflow::Flag(
"xla_enable_fast_math", &flag_values->xla_enable_fast_math,
"Enable unsafe fast-math optimizations in the compiler; "
@ -61,18 +70,31 @@ void AllocateFlags() {
"xla_backend_optimization_level",
&flag_values->xla_backend_optimization_level,
"Numerical optimization level for the XLA compiler backend."),
tensorflow::Flag(
"xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
"Comma-separated list of hlo passes to be disabled. These names "
"must exactly match the passes' names; no whitespace around "
"commas."),
tensorflow::Flag("xla_embed_ir_in_executable",
&flag_values->xla_embed_ir_in_executable,
"Embed the compiler IR as a string in the executable."),
tensorflow::Flag("xla_gpu_cuda_data_dir",
&flag_values->xla_gpu_cuda_data_dir,
"If non-empty, speficies a local directory containing "
"ptxas and nvvm libdevice files; otherwise we use "
"those from runfile directories."),
tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz,
"If true, flush-to-zero semantics are enabled in the "
"code generated for GPUs."),
tensorflow::Flag(
"xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to,
"Dump compilation artifacts as JSON into this directory."),
tensorflow::Flag("xla_backend_extra_options",
&flag_values->xla_backend_extra_options,
"Extra options to pass to a backend; "
"comma-separated list of 'key=val' strings (=val "
"may be omitted); no whitespace around commas."),
"may be omitted); no whitespace around commas.")});
tensorflow::Flag(
"xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
"Comma-separated list of HLO passes to be disabled. These names "
"must exactly match the passes' names; "
"no whitespace around commas.")});
ParseFlagsFromEnv(*flag_objects);
}
@ -99,6 +121,11 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math);
options.set_xla_backend_optimization_level(
flag_values->xla_backend_optimization_level);
options.set_xla_embed_ir_in_executable(
flag_values->xla_embed_ir_in_executable);
options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir);
options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz);
std::vector<string> extra_options_parts =
tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ',');

View File

@ -1,88 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for XLA's gpu_backend_lib module.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static GpuBackendLibFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new GpuBackendLibFlags;
flags->dump_temp_products_to = "";
flags->ftz = false;
flags->fma = true;
flags->verbose_ptx_asm = false;
flags->kernel = "";
flags->llvm_dump_passes = false;
flags->llvm_cl_opts = "";
flags->dump_ir_before_passes = false;
flags->opt_level = 3;
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to,
"dump temporary compilation products to this directory. "
"If empty, no dump is produced"),
tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"),
tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"),
tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm,
"emit PTX assembly with extra comments"),
tensorflow::Flag("kernel", &flags->kernel,
"only emit the IR and PTX for this kernel"),
tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes,
"dump the passes LLVM runs to stderr"),
tensorflow::Flag(
"llvm_cl_opts", &flags->llvm_cl_opts,
"comma-separated list of command line options to pass to "
"LLVM. For example, --llvm_cl_opts=--print-before=loop-unroll"),
tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes,
"dump the IR before each optimization pass in "
"sequentially-named files."),
tensorflow::Flag("opt_level", &flags->opt_level,
"optimization level (default to 3)"),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's gpu_backend_lib
// module.
void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the GpuBackendLibFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuBackendLibFlags* GetGpuBackendLibFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -1,55 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
// Legacy flags for XLA's gpu_backend_lib module.
#include <vector>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib
// module.
void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's gpu_backend_lib module.
typedef struct {
string dump_temp_products_to; // temporary compilation products dir
bool ftz; // flush to zero semantics
bool fma; // use FMA synthesis
bool verbose_ptx_asm; // emit PTX assembly with extra comments
string kernel; // only emit the IR and PTX for this kernel
bool llvm_dump_passes; // dump the passes LLVM runs to stderr
string llvm_cl_opts; // comma-separated list of LLVM options
bool dump_ir_before_passes; // dump IR before each pass
int32 opt_level; // optimization level
} GpuBackendLibFlags;
// Return a pointer to the GpuBackendLibFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuBackendLibFlags* GetGpuBackendLibFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_

View File

@ -1,76 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for XLA's gpu_compiler module.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static GpuCompilerFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new GpuCompilerFlags;
flags->xla_gpu_embed_ir = false;
flags->xla_cuda_data_dir = "./cuda_sdk_lib";
flags->xla_gpu_dump_debug_json_to = "";
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag(
"xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
"Embed the LLVM IR module string in the resultant GpuExecutable."),
tensorflow::Flag(
"xla_cuda_data_dir", &flags->xla_cuda_data_dir,
"If non-empty, specifies a local directory containing ptxas and "
"nvvm libdevice files. Otherwise, by default, we use those from "
"runfile directories."),
tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path,
"The path to ptxas. Required to log stats of the ptx."),
tensorflow::Flag("xla_gpu_dump_debug_json_to",
&flags->xla_gpu_dump_debug_json_to,
"Dump debug JSON to this directory."),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's gpu_compiler
// module.
void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the GpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuCompilerFlags* GetGpuCompilerFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -1,55 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
// Legacy flags for XLA's gpu_compiler module.
#include <vector>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's gpu_compiler
// module.
void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's gpu_compiler module.
typedef struct {
bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant
// GpuExecutable.
string xla_cuda_data_dir; // If non-empty, specifies a local directory
// containing ptxas and nvvm libdevice files.
// Otherwise, by default, we use those from runfile
// directories.
string xla_ptxas_path; // The path to ptxas. Required to log stats of
// the ptx.
string xla_gpu_dump_debug_json_to; // Dump debug JSON to this directory.
} GpuCompilerFlags;
// Return a pointer to the GpuCompilerFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
GpuCompilerFlags* GetGpuCompilerFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_

View File

@ -1,63 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for XLA's llvm_util module.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static LlvmUtilFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new LlvmUtilFlags;
flags->xla_emit_tbaa = true;
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa,
"Perform type-based alias analysis optimizations for "
"LLVM-based backends."),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's llvm_util
// module.
void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the LlvmUtilFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
LlvmUtilFlags* GetLlvmUtilFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla

View File

@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
// Legacy flags for XLA's llvm_util module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's llvm_util module.
void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's llvm_util module.
typedef struct {
bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for
// LLVM-based backends.
} LlvmUtilFlags;
// Return a pointer to the LlvmUtilFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
LlvmUtilFlags* GetLlvmUtilFlags();
} // namespace legacy_flags
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_

View File

@ -321,6 +321,7 @@ Status Literal::Copy(const Literal& src_literal,
}
std::unique_ptr<Literal> Literal::Relayout(const Layout& layout) const {
CHECK(ShapeUtil::IsArray(shape()));
std::unique_ptr<Literal> result = CloneToUnique();
*result->mutable_shape()->mutable_layout() = layout;
@ -754,10 +755,30 @@ void Literal::EachCellAsString(
}
namespace {
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
auto result_literal = MakeUnique<Literal>();
Shape* result_shape = result_literal->mutable_shape();
*result_shape = src_literal.shape();
result_shape->set_element_type(
primitive_util::NativeToPrimitiveType<NativeDestT>());
result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
src_literal.GetArraySlice<NativeSrcT>();
tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
result_literal->GetMutableArraySlice<NativeDestT>();
int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = static_cast<NativeDestT>(src_data[i]);
}
return result_literal;
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
return ConvertBetweenNativeTypes<
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
@ -782,19 +803,20 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
return tensorflow::errors::InvalidArgument(
"Unimplemented: ConvertIfDestTypeMatches for type " +
PrimitiveType_Name(src_literal.shape().element_type()));
return InvalidArgument(
"Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str());
}
}
}
} // namespace
StatusOr<std::unique_ptr<Literal>> LiteralUtil::ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (src_literal.shape().element_type()) {
StatusOr<std::unique_ptr<Literal>> Literal::Convert(
PrimitiveType primitive_dest_type) const {
switch (shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
@ -807,9 +829,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralUtil::ConvertIfSrcTypeMatches(
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
return tensorflow::errors::InvalidArgument(
"Unimplemented: ConvertIfSrcTypeMatches for type " +
PrimitiveType_Name(src_literal.shape().element_type()));
return InvalidArgument("Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str());
}
}

View File

@ -251,7 +251,7 @@ class Literal {
*other = temp;
}
// CreatesCreate new literal of a given rank. To minimize ambiguity (for users
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
// native type. For example:
//
@ -362,10 +362,10 @@ class Literal {
template <typename NativeT>
std::unique_ptr<Literal> Replicate(int64 times) const;
// Creates a literal by converting each element in this literal to a new
// type.
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> Convert() const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible.
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
// Creates a literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@ -444,10 +444,21 @@ class Literal {
template <typename NativeT>
void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
// Retrieves the mutable array slice interface which can be used to manipulate
// pre-allocated literal values.
// Returns a (Mutable)ArraySlice view of the array for this literal for the
// given NativeT (e.g., float). These functions map native type to XLA
// PrimitiveType via template specialization. The unspecialized forms below
// aborts to handle the error case where the given native type does not map to
// an XLA primitive type.
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice();
tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
static_assert(!std::is_same<NativeT, NativeT>::value,
"Cannot map native type to primitive type.");
}
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice() {
static_assert(!std::is_same<NativeT, NativeT>::value,
"Cannot map native type to primitive type.");
}
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
@ -588,17 +599,6 @@ class Literal {
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
private:
// Returns an ArraySlice view of the array for this literal for the given
// NativeT (e.g., float). These functions map native type to XLA PrimitiveType
// via template specialization. The unspecialized forms below aborts to handle
// the error case where the given native type does not map to an XLA primitive
// type.
template <typename NativeT>
tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
static_assert(!std::is_same<NativeT, NativeT>::value,
"Cannot map native type to primitive type.");
}
// Copy from a LiteralProto instance.
void CopyFromProto(const LiteralProto& literal_proto);
@ -646,544 +646,6 @@ class Literal {
std::vector<Literal> tuple_literals_;
};
// Utility class for dealing with XLA literal values. Most methods are
// templated by native (host) type which corresponds to a unique XLA
// PrimitiveType. See ComputationBuilder for details. Not all primitive types
// defined in xla_data.proto have a corresponding native type or even have a
// storage location in the Literal proto yet (for example, primitive type F16).
//
// TODO(dnovillo) - All functions in this class simply redirect to the
// corresponding function in class Literal. Remove this class after converting
// all user code to use Literal directly.
class LiteralUtil {
public:
// Creates new literal of a given rank. To minimize ambiguity (for users and
// the compiler) these CreateR[0-2] methods should explicitly specify the
// native type. For example:
//
// CreateR1<float>({1.0, 42.0});
// CreateR2<uint32>({{1, 2}, {3, 4}});
//
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateR0(NativeT value) {
return Literal::CreateR0(value);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
return Literal::CreateR1(values);
}
static std::unique_ptr<Literal> CreateR1(
const tensorflow::core::Bitmap& values) {
return Literal::CreateR1(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return Literal::CreateR2(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
return Literal::CreateR2WithLayout(values, layout);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values) {
return Literal::CreateR3(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
return Literal::CreateR3WithLayout(values, layout);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
return Literal::CreateR4(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
const Layout& layout) {
return Literal::CreateR4WithLayout(values, layout);
}
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape) {
return Literal::CreateFromShape(shape);
}
// Creates a new Literal object with its values havings the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
static std::unique_ptr<Literal> CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions) {
return Literal::CreateFromDimensions(primitive_type, dimensions);
}
// Copies the values from src_literal, starting at src_base shape indexes,
// to dest_literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size.
//
// The src_literal and dest_literal must have the same primitive type,
// src_base+copy_size must fit the source literal dimensions, as well as
// dest_base+copy_size must fit the destination literal dimensions.
static Status Copy(const Literal& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
Literal* dest_literal,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
return dest_literal->Copy(src_literal, src_base, dest_base, copy_size);
}
// Creates a new value that has the equivalent value as literal, but conforms
// to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major
// dimension layout can be re-laid-out as {1, 0} minor-to-major dimension
// layout and the value in the cell at any given logical index (i0, i1) will
// be the same.
//
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
static std::unique_ptr<Literal> Relayout(const Literal& literal,
const Layout& new_layout) {
return literal.Relayout(new_layout);
}
// Reshapes literal 'input' to have 'shape'. Both the original shape and
// 'shape' must contain the same number of elements. The implementation
// currently only supports monotonic dim0-major layouts.
static StatusOr<std::unique_ptr<Literal>> Reshape(
const xla::Literal& input, tensorflow::gtl::ArraySlice<int64> shape) {
return input.Reshape(shape);
}
// Creates a new literal by reordering the dimensions of the original literal.
// The given `permutation` must be a permutation of the dimension numbers
// in the original literal, and it specifies the order of the new dimensions
// in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
static std::unique_ptr<Literal> Transpose(
const Literal& literal, tensorflow::gtl::ArraySlice<int64> permutation) {
return literal.Transpose(permutation);
}
// Creates a sub-array from the given literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
// same rank and layout as for the given literal. The number of indices in
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
static std::unique_ptr<Literal> Slice(
const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) {
return literal.Slice(start_indices, limit_indices);
}
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input
// literal replicated four times.
template <typename NativeT>
static std::unique_ptr<Literal> Replicate(const Literal& input, int64 times) {
return input.Replicate<NativeT>(times);
}
// Creates a literal by converting each element in an original literal to a
// new type.
template <typename NativeSrcT, typename NativeDestT>
static std::unique_ptr<Literal> Convert(const Literal& literal) {
return literal.Convert<NativeSrcT, NativeDestT>();
}
// Convert a literal to another primitive type, but only if the literal
// type is connvertable into the destination type
static StatusOr<std::unique_ptr<Literal>> ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type);
// Creates a literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type) {
return Literal::Zero(primitive_type);
}
// Creates a literal value one of the given primitive type.
static Literal One(PrimitiveType primitive_type) {
return Literal::One(primitive_type);
}
// Creates a literal value containing the minimum value of the given
// primitive type. For floating-point types, returns -inf.
static Literal MinValue(PrimitiveType primitive_type) {
return Literal::MinValue(primitive_type);
}
// Creates a literal value containing the maximum value of the given
// primitive type. For floating-point types, returns inf.
static Literal MaxValue(PrimitiveType primitive_type) {
return Literal::MaxValue(primitive_type);
}
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value);
}
// Creates a new literal from an array. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return Literal::CreateR2FromArray2D(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return Literal::CreateR2FromArray2DWithLayout(values, layout);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return Literal::CreateR3FromArray3D(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return Literal::CreateR3FromArray3DWithLayout(values, layout);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return Literal::CreateR4FromArray4D(values);
}
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return Literal::CreateR4FromArray4DWithLayout(values, layout);
}
// Creates a new vector of U8s literal value from a string.
static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value) {
return Literal::CreateR1U8(value);
}
// Creates a linspace-populated literal with the given number of rows and
// columns.
static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
int64 rows, int64 cols) {
return Literal::CreateR2F32Linspace(from, to, rows, cols);
}
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
return Literal::CreateR3Projected(values, projection);
}
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
return Literal::CreateR4Projected(values, projection_p, projection_z);
}
// Clones literal into an owned unique_ptr version.
static std::unique_ptr<Literal> CloneToUnique(const Literal& literal) {
return literal.CloneToUnique();
}
// Returns the linear index of the given index within the literal's
// element_type repeated field.
static int64 LinearIndex(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> multi_index) {
return literal.LinearIndex(multi_index);
}
// Gets or sets an element in the literal at the given index. The index is
// CHECKed against the dimension sizes.
template <typename NativeT>
static NativeT Get(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> multi_index) {
return literal.Get<NativeT>(multi_index);
}
template <typename NativeT>
static void Set(Literal* literal,
tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) {
literal->Set(multi_index, value);
}
// Retrieves the mutable array slice interface which can be used to manipulate
// pre-allocated literal values.
template <typename NativeT>
static tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice(
Literal* literal) {
return literal->GetMutableArraySlice<NativeT>();
}
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
template <typename NativeT>
static NativeT GetFirstElement(const Literal& literal) {
return literal.GetFirstElement<NativeT>();
}
// As Get(), but determines the correct type and converts the value
// into text.
static string GetAsString(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> multi_index) {
return literal.GetAsString(multi_index);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
static std::unique_ptr<Literal> MakeIdentityR2(int64 size) {
return Literal::MakeIdentityR2<NativeT>(size);
}
// Returns a tuple literal composed of given literals.
static std::unique_ptr<Literal> MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
return Literal::MakeTuple(elements);
}
// Validates that the data payload of the literal matches the literal shape;
// if it does not, an appropriate status is returned.
static tensorflow::Status ValidateLiteral(const Literal& literal) {
return literal.ValidateLiteral();
}
// Returns a string representation of the literal value.
static string ToString(const Literal& literal) { return literal.ToString(); }
// Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of
// the element's value.
//
// This function is useful if you want a polymorphic representation
// of the tensor's elements (turning it to a string for something
// like representation in a protobuf).
static void EachCellAsString(
const Literal& literal,
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) {
literal.EachCellAsString(per_cell);
}
template <typename NativeT>
static void EachCell(
const Literal& literal,
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
per_cell) {
literal.EachCell<NativeT>(per_cell);
}
// Templated methods which populate the given repeated field in the Literal
// proto with the given value(s). The Shape field of the Literal proto is set
// to match the array dimensions and type. Examples:
//
// // Populate with floats.
// Array2D<float> float_values = ...
// PopulateR2FromArray2D(values, literal);
//
// // Populate with int32s.
// PopulateR2({{1, 2}, {3, 4}}, literal);
//
template <typename NativeT>
static void PopulateR0(NativeT values, Literal* literal) {
literal->PopulateR0(values);
}
template <typename NativeT>
static void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values,
Literal* literal) {
literal->PopulateR1(values);
}
static void PopulateR1(const tensorflow::core::Bitmap& values,
Literal* literal) {
literal->PopulateR1(values);
}
template <typename NativeT>
static void PopulateR2(
std::initializer_list<std::initializer_list<NativeT>> values,
Literal* literal) {
literal->PopulateR2(values);
}
template <typename NativeT>
static void PopulateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout, Literal* literal) {
literal->PopulateR2WithLayout(values, layout);
}
template <typename NativeT>
static void PopulateR2FromArray2D(const Array2D<NativeT>& values,
Literal* literal) {
literal->PopulateR2FromArray2D(values);
}
template <typename NativeT>
static void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout,
Literal* literal) {
literal->PopulateR2FromArray2DWithLayout(values, layout);
}
template <typename NativeT>
static void PopulateR3FromArray3D(const Array3D<NativeT>& values,
Literal* literal) {
literal->PopulateR3FromArray3D(values);
}
template <typename NativeT>
static void PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout,
Literal* literal) {
literal->PopulateR3FromArray3DWithLayout(values, layout);
}
template <typename NativeT>
static void PopulateR4FromArray4D(const Array4D<NativeT>& values,
Literal* literal) {
literal->PopulateR4FromArray4D(values);
}
template <typename NativeT>
static void PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout,
Literal* literal) {
literal->PopulateR4FromArray4DWithLayout(values, layout);
}
// Populates literal values by calling the generator function for every cell
// in the literal object.
template <typename NativeT>
static Status Populate(
Literal* literal,
const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
generator) {
return literal->Populate(generator);
}
// Creates a Literal of the given dimensions with all elements set to the
// given value.
template <typename NativeT>
static void PopulateWithValue(NativeT value,
tensorflow::gtl::ArraySlice<int64> dimensions,
Literal* literal) {
return literal->PopulateWithValue(value, dimensions);
}
// Returns a pointer to the underlying vector containing the array data. Use
// with care.
static const void* InternalData(const Literal& literal) {
return literal.InternalData();
}
static void* MutableInternalData(Literal* literal) {
return literal->MutableInternalData();
}
// Allocates space in the underlying vector of the literal sufficient to hold
// num_elements of the literal's primitive type. Values in the vector are set
// to zero. num_elements must equal the number of elements in the literals
// shape.
static void Reserve(int64 num_elements, Literal* literal) {
literal->Reserve(num_elements);
}
// Allocates space in the underlying vector of the literal sufficient to hold
// num_elements of the literal's primitive type and sets each element in the
// literal to the given value. num_elements must equal the number of elements
// in the literals shape.
template <typename NativeT>
static void Resize(int64 num_elements, NativeT value, Literal* literal) {
literal->Resize(num_elements, value);
}
// Returns true if the two given literals have the same shape and
// values. Layout is not considered in the comparison.
static bool Equal(const Literal& literal1, const Literal& literal2) {
return literal1.Equal(literal2);
}
// Returns whether every element in the given literal is equal to value.
//
// value is an int8 because we expect this to be called with small
// compile-time constants (0, -1, etc.) and so that whatever value you pass
// can be represented exactly by floating-point types as small as 16 bits.
//
// If value doesn't fit in literal's type, returns false. Values of 1/0 are
// considered equal to true/false; other values are not considered equal to
// true.
static bool IsAll(const Literal& literal, int8 value) {
return literal.IsAll(value);
}
// Like IsAll(const Literal&, int8), except we check whether the literal is
// equal to a particular floating-point number.
//
// If the literal is not a floating-point value, this always returns false.
//
// This casts value to the type of literal, then compares using ==. The usual
// admonishments about floating-point equality checks apply. We expect you to
// use this to check for values that can be expressed precisely as a float,
// e.g. -0.5.
static bool IsAllFloat(const Literal& literal, float value) {
return literal.IsAllFloat(value);
}
// Returns whether the literal is zero at the specified index. The literal
// must be an array.
static bool IsZero(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> indices) {
return literal.IsZero(indices);
}
TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil);
};
// Declarations of template specializations for GetArraySlice and
// GetMutableArraySlice. The specializations map native type to XLA primitive
// type.
@ -1759,27 +1221,6 @@ void Literal::PopulateWithValue(NativeT value,
Resize<NativeT>(ShapeUtil::ElementsIn(shape()), value);
}
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> Literal::Convert() const {
const Shape& this_shape = shape();
auto result_literal = MakeUnique<Literal>();
Shape* result_shape = result_literal->mutable_shape();
*result_shape = this_shape;
result_shape->set_element_type(
primitive_util::NativeToPrimitiveType<NativeDestT>());
result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
GetArraySlice<NativeSrcT>();
tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
result_literal->GetMutableArraySlice<NativeDestT>();
int64 num_elements = ShapeUtil::ElementsIn(this_shape);
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = static_cast<NativeDestT>(src_data[i]);
}
return result_literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
Literal::CreateFullWithMonotonicDim0MajorLayout(

File diff suppressed because it is too large Load Diff

View File

@ -58,8 +58,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
}
int64 elements = ShapeUtil::ElementsIn(shape);
LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
result.get());
result.get()->Resize(elements, std::numeric_limits<float>::quiet_NaN());
std::vector<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->data());
uint64 bytes = elements * sizeof(float);

View File

@ -52,7 +52,7 @@ class ReferenceUtilTest : public ::testing::Test {
TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = Literal::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
}
@ -62,7 +62,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
{7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f},
});
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = Literal::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*actual_literal, ErrorSpec(0.0001));
}
@ -70,7 +70,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
auto actual_literal = Literal::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
ErrorSpec(0.0001));
}
@ -78,7 +78,7 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
auto actual_literal = Literal::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
ErrorSpec(0.0001));
}
@ -86,7 +86,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = Literal::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
ErrorSpec(0.0001));
}
@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
return value + row + col;
};
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = Literal::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*actual_literal, ErrorSpec(0.0001));
}
@ -107,7 +107,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
auto actual_literal = Literal::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
@ -124,7 +124,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
};
auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
auto actual_literal = Literal::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
@ -161,7 +161,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
}));
// clang-format on
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@ -195,7 +195,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
}));
// clang-format on
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@ -247,7 +247,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
}});
// clang-format on
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@ -296,7 +296,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
Array4D<float> expected({{{{2514, 2685}}}});
// clang-format on
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
auto actual_literal = Literal::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@ -309,7 +309,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
auto actual = ReferenceUtil::ApplyElementwise2D(
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
auto actual_literal = Literal::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
}

View File

@ -112,6 +112,7 @@ cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
@ -330,6 +331,7 @@ cc_library(
hdrs = ["backend.h"],
deps = [
":compiler",
":computation_placer",
":device_memory_allocator",
":platform_util",
":pool",
@ -338,7 +340,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@ -382,6 +383,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:lib",
@ -416,6 +418,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@ -948,6 +951,26 @@ cc_test(
],
)
cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
hdrs = ["computation_placer.h"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
alwayslink = True, # Contains per-platform computation placer registration
)
cc_library(
name = "generic_transfer_manager",
srcs = ["generic_transfer_manager.cc"],
@ -1165,6 +1188,7 @@ cc_library(
deps = [
":call_graph",
":hlo",
":hlo_ordering",
":liveness_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
@ -1398,7 +1422,6 @@ cc_library(
":call_graph",
":flatten_call_graph",
":hlo",
":hlo_cost_analysis",
":hlo_dce",
":hlo_ordering",
":liveness_util",
@ -1572,10 +1595,8 @@ cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
],
)
@ -1777,7 +1798,6 @@ cc_library(
":hlo",
":hlo_proto",
"//tensorflow/compiler/xla:status",
"//tensorflow/core:lib",
],
)

View File

@ -48,7 +48,7 @@ namespace {
// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
return operand->opcode() == HloOpcode::kConstant &&
LiteralUtil::IsAll(operand->literal(), value);
operand->literal().IsAll(value);
}
bool IsAll(const HloInstruction* op, int8 value) {
@ -126,10 +126,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override;
Status HandleConvert(HloInstruction* convert) override;
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override;
@ -179,11 +178,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleMaximum(HloInstruction* maximum) override;
Status HandleMinimum(HloInstruction* minimum) override;
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
@ -334,16 +330,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
HloInstruction* operand) {
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
if (operand->opcode() == HloOpcode::kCopy) {
if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
return ReplaceWithNewInstruction(
copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy,
operand->operands()[0]));
copy, HloInstruction::CreateUnary(
copy->shape(), HloOpcode::kCopy,
copy->mutable_operand(0)->mutable_operand(0)));
}
// All copies can be eliminated (assuming layout constraints are satisified).
ReplaceInstructionIfSameShape(copy, operand);
ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
return Status::OK();
}
@ -469,7 +465,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
ShapeUtil::HasZeroElements(lhs->shape()) ||
ShapeUtil::HasZeroElements(rhs->shape())) {
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
@ -507,7 +503,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
{0}, add_reduce_computation));
@ -531,7 +527,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
HloInstruction* reduce;
if (ShapeUtil::Rank(rhs->shape()) == 1) {
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
@ -571,7 +567,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(),
{lhs->shape().dimensions(0)}),
@ -792,12 +788,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// A conversion to the same element type as the operand is a nop and can be
// removed. A conversion of a constant can be simplified by making a new
// constant.
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
HloInstruction* operand) {
PrimitiveType src_type = operand->shape().element_type();
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
PrimitiveType src_type = convert->operand(0)->shape().element_type();
PrimitiveType dest_type = convert->shape().element_type();
if (src_type == dest_type) {
return ReplaceInstruction(convert, operand);
return ReplaceInstruction(convert, convert->mutable_operand(0));
}
return Status::OK();
}
@ -897,8 +892,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
HloInstruction* rhs) {
VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
LiteralUtil::One(power->shape().element_type())));
auto one = HloInstruction::CreateConstant(
Literal::One(power->shape().element_type()).CloneToUnique());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@ -923,9 +918,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
LiteralUtil::One(rhs->shape().element_type()))));
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::One(rhs->shape().element_type()).CloneToUnique()));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
one, lhs));
@ -1008,7 +1002,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// dimension.
if (ShapeUtil::HasZeroElements(reshape->shape())) {
auto empty_constant = HloInstruction::CreateConstant(
LiteralUtil::CreateFromShape(reshape->shape()));
Literal::CreateFromShape(reshape->shape()));
return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
}
@ -1208,8 +1202,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
// try to get more fancy about proving equivalence in cases beyond that.
if (pad_value->opcode() != HloOpcode::kConstant ||
reduce_init_value->opcode() != HloOpcode::kConstant ||
!LiteralUtil::Equal(pad_value->literal(),
reduce_init_value->literal())) {
!pad_value->literal().Equal(reduce_init_value->literal())) {
VLOG(10) << "Not folding pad into reduce-window due to different pad "
"values.";
return Status::OK();
@ -1396,9 +1389,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
return true;
}
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
HloInstruction* lhs,
HloInstruction* rhs) {
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
// Match the following tree:
// min_operand operand
// \ /
@ -1429,9 +1420,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
HloInstruction* lhs,
HloInstruction* rhs) {
Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
// Match the following tree:
// max_operand operand
// \ /

View File

@ -55,7 +55,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
@ -76,7 +76,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloInstruction* bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
builder.AddInstruction(
@ -99,7 +99,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
HloInstruction* bcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
builder.AddInstruction(
@ -123,7 +123,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
@ -145,7 +145,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
@ -167,7 +167,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
@ -300,7 +300,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
@ -315,7 +315,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->literal()), 1);
EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}
// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
@ -325,7 +325,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
@ -344,8 +344,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
<< ShapeUtil::HumanString(root->shape());
EXPECT_EQ(root->dimensions().size(), 0);
EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
1);
EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
// Test that pow(A, 1) is simplified to A.
@ -355,7 +354,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
@ -378,7 +377,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* two = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
@ -401,7 +400,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* negative_one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
@ -416,8 +415,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Divide(op::Constant(), param0));
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
1);
EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
@ -451,7 +449,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@ -519,7 +517,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* empty_literal = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
@ -550,7 +548,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* empty_literal = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
@ -735,7 +733,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
HloOpcode::kMaximum, movable_reshape, zero));
@ -1035,7 +1033,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
@ -1066,7 +1064,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
@ -1376,9 +1374,9 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* min_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloInstruction* max_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMinimum, param0, min_value));
builder.AddInstruction(
@ -1406,9 +1404,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* min_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloInstruction* max_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMaximum, param0, max_value));
builder.AddInstruction(
@ -1437,9 +1435,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* min_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloInstruction* max_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kMaximum, param0, max_value));
builder.AddInstruction(
@ -1497,9 +1495,9 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* min_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloInstruction* max_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMaximum, param0, max_value));
HloInstruction* fmax = builder.AddInstruction(
@ -1566,7 +1564,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloComputation::Builder builder(TestName());
HloInstruction* forty_two = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
HloInstruction* broadcast =
@ -1614,7 +1612,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
@ -1645,7 +1643,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, pad, reduce_init_value, window,
@ -1714,9 +1712,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
HloComputation::Builder call_builder(TestName() + ".Call");
HloInstruction* zero = call_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
HloInstruction* one = call_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

View File

@ -22,7 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
return platform_;
}
BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
number_of_replicas_ = number_of_replicas;
return *this;
}
int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
int num_threads) {
intra_op_parallelism_threads_ = num_threads;
@ -85,20 +77,17 @@ struct Backend::EigenThreadPoolWrapper {
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
const BackendOptions& options) {
int64 replica_count = options.number_of_replicas();
if (replica_count == -1) {
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
replica_count = flags->xla_replicas;
}
perftools::gputools::Platform* platform = options.platform();
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto stream_executors,
PlatformUtil::GetStreamExecutors(platform));
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto computation_placer,
ComputationPlacer::GetForPlatform(platform));
std::unique_ptr<Backend> backend(
new Backend(replica_count, platform, compiler, stream_executors,
transfer_manager, options.intra_op_parallelism_threads()));
new Backend(platform, compiler, stream_executors, transfer_manager,
computation_placer, options.intra_op_parallelism_threads()));
return std::move(backend);
}
@ -132,34 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
}
Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
perftools::gputools::Platform* platform, Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
TransferManager* transfer_manager, int intra_op_parallelism_threads)
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
replica_count_(replica_count) {
computation_placer_(computation_placer) {
// The given set of stream executors set may include invalid executors.
for (se::StreamExecutor* exec : stream_executors) {
if (exec != nullptr) {
stream_executors_.push_back(exec);
}
}
CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
// Create a memory allocator for the valid stream executors.
memory_allocator_ =
MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
// First check that there are some non-null stream executors to avoid issuing
// an error mentioning replicas in the common case of requesting just 1
// replica, which means no replication.
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';
CHECK_GE(stream_executors_.size(), replica_count)
<< "Requested more replicas than there are devices for backend "
<< platform_->Name() << '.';
if (platform->id() == se::host::kHostPlatformId) {
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
@ -179,36 +159,6 @@ int Backend::default_device_ordinal() const {
return default_stream_executor()->device_ordinal();
}
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
int device_ordinal) const {
if (stream_executors_[device_ordinal] == nullptr) {
return InvalidArgument("device %s not supported by XLA service",
device_name(device_ordinal).c_str());
}
// Find replica_count_ stream executors starting from the given device
// ordinal.
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (se::StreamExecutor* exec : stream_executors_) {
CHECK(exec != nullptr);
if (exec->device_ordinal() >= device_ordinal) {
replicas.push_back(exec);
if (replicas.size() >= replica_count_) {
return replicas;
}
}
}
return InvalidArgument(
"Not enough devices for replicas for the device ordinal %d",
device_ordinal);
}
std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
CHECK_GE(stream_executors_.size(), replica_count_);
return Replicas(default_device_ordinal()).ValueOrDie();
}
tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
}

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -46,12 +47,6 @@ class BackendOptions {
BackendOptions& set_platform(perftools::gputools::Platform* platform);
perftools::gputools::Platform* platform() const;
// Set the number of replicas to use when compiling replicated
// programs. The default is -1 meaning that the value is read from
// the xla_replicas flag.
BackendOptions& set_number_of_replicas(int number_of_replicas);
int number_of_replicas() const;
// Sets the thread pool size for parallel execution of an individual operator.
// The default value of -1 will result in initializing the thread pool with
// the number of threads equal to the number of cores in the system.
@ -60,7 +55,6 @@ class BackendOptions {
private:
perftools::gputools::Platform* platform_ = nullptr;
int number_of_replicas_ = -1;
int intra_op_parallelism_threads_ = -1;
};
@ -74,8 +68,7 @@ class Backend {
public:
using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
// Creates a new backend for the given platform with the given number of
// replicas.
// Creates a new backend.
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
const BackendOptions& options);
@ -92,6 +85,7 @@ class Backend {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
ComputationPlacer* computation_placer() const { return computation_placer_; }
// Returns the number of devices of the platform type which are visible. Not
// all of these devices may be usable by XLA.
@ -107,24 +101,13 @@ class Backend {
return stream_executors_;
}
// Returns the replicas for the default stream executor.
//
// When the number of replicas is R, the first R stream executors are assigned
// to the replicas of the default stream executor.
std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
// Returns the replicas for the given device_ordinal. The given device ordinal
// is considered to be the first device ordinal among the replicas. Returns an
// error status if the stream executor for the given given device ordinal does
// not exist or if there are not enough stream executors for the replicas.
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
int device_ordinal) const;
// Return the stream executor for the given device ordinal.
// Returns the stream executor for the given device ordinal.
StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
int device_ordinal) const;
// Return the stream executor for the default device ordinal.
// Returns the stream executor for the default device ordinal. This stream
// executor can only be used when the number of computations is 1 (replication
// can be > 1).
perftools::gputools::StreamExecutor* default_stream_executor() const {
CHECK(!stream_executors_.empty());
return stream_executors_[0];
@ -174,18 +157,19 @@ class Backend {
private:
struct EigenThreadPoolWrapper;
Backend(int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
Backend(perftools::gputools::Platform* platform, Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
TransferManager* transfer_manager, int intra_op_parallelism_threads);
TransferManager* transfer_manager,
ComputationPlacer* computation_placer,
int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
perftools::gputools::Platform* platform_;
Compiler* compiler_;
TransferManager* transfer_manager_;
int64 replica_count_ = -1;
ComputationPlacer* computation_placer_;
// Vector of stream executors. stream_executors_[0] is the default executor.
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;

View File

@ -1074,7 +1074,8 @@ void BufferAssigner::AddSetToColocatedBufferSets(
// different while instructions.
void BufferAssigner::AddWhileSetToColocatedBufferSets(
const std::vector<const LogicalBuffer*>& colocated_set,
const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
const LogicalBuffer* while_init_buffer,
const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
const HloComputation& computation, const BufferLiveness& buffer_liveness,
const LogicalBuffer::SizeFunction& buffer_size,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
@ -1137,16 +1138,30 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets(
continue;
}
// Skip predecessor set if the live range of any predecessor buffers
// overlaps with 'while_init_buffer'. Note that tuple element buffer
// forwarding can cause the same buffer to appear on both sides of the
// interference comparison below.
if (std::any_of(
predecessor_while_buffers.begin(), predecessor_while_buffers.end(),
[while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) {
return while_init_buffer->id() != buffer->id() &&
buffer_liveness.MayInterfere(*while_init_buffer, *buffer);
})) {
// Skip predecessor set if the live range of any predecessor
// buffers overlaps with 'while_init_buffer' or
// 'while_result_buffer' (we need to check both since they're
// aliased together, but the points-to analysis is unaware of this
// aliasing). Note that tuple element buffer forwarding can cause
// the same buffer to appear on both sides of the interference
// comparison below.
auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) {
if (while_init_buffer->id() != buffer->id() &&
buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) {
return true;
}
if (while_result_buffer->id() != buffer->id() &&
buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) {
return true;
}
return false;
};
if (std::any_of(predecessor_while_buffers.begin(),
predecessor_while_buffers.end(),
may_interfere_with_init_or_result)) {
continue;
}
@ -1209,8 +1224,8 @@ void BufferAssigner::BuildColocatedBufferSets(
AddBufferToColocatedSet(while_hlo->operand(0), index,
points_to_analysis, &colocated_set);
// Add while.result.
AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
&colocated_set);
auto* result_buffer = AddBufferToColocatedSet(
while_hlo, index, points_to_analysis, &colocated_set);
// Add while.cond.parameter.
AddBufferToColocatedSet(
while_hlo->while_condition()->parameter_instruction(0), index,
@ -1224,8 +1239,9 @@ void BufferAssigner::BuildColocatedBufferSets(
while_hlo->while_body()->root_instruction(), index,
points_to_analysis, &colocated_set);
AddWhileSetToColocatedBufferSets(
colocated_set, init_buffer, while_hlo, *computation,
buffer_liveness, buffer_size, colocated_buffer_sets);
colocated_set, init_buffer, result_buffer, while_hlo,
*computation, buffer_liveness, buffer_size,
colocated_buffer_sets);
});
} else if (opcode == HloOpcode::kCall) {
const HloInstruction* call_hlo = instruction;

View File

@ -511,7 +511,8 @@ class BufferAssigner {
// colocated buffers for while instructions.
void AddWhileSetToColocatedBufferSets(
const std::vector<const LogicalBuffer*>& colocated_set,
const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo,
const LogicalBuffer* while_init_buffer,
const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
const HloComputation& computation, const BufferLiveness& buffer_liveness,
const LogicalBuffer::SizeFunction& buffer_size,
std::vector<ColocatedBufferSet>* colocated_buffer_sets);

View File

@ -105,7 +105,7 @@ class BufferAssignmentTest : public HloTestBase {
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
auto value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
return builder.Build();
@ -122,7 +122,7 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const4 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
@ -147,9 +147,9 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const1 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto indexc = builder.AddInstruction(
@ -264,7 +264,7 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
TEST_F(BufferAssignmentTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@ -278,9 +278,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
// no buffers assigned, and their consumer has a buffer.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
auto module = CreateNewModule();
@ -298,7 +298,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
// This computation copies a constant to output.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
auto module = CreateNewModule();
@ -586,7 +586,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
auto exp2 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
/*shape=*/f32vec10_,
/*operand=*/exp2,
@ -634,9 +634,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto const3 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
@ -1075,9 +1075,8 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output is
// properly handled.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()})));
builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
{Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@ -1369,9 +1368,9 @@ class WhileBufferAssignmentTest : public HloTestBase {
builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
auto ten = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
return builder.Build();
@ -1429,7 +1428,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
HloInstruction::CreateParameter(2, data_shape_, "weights1"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
@ -1484,7 +1483,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
@ -1532,16 +1531,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param"));
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
sub_computation = module->AddEmbeddedComputation(builder.Build(add));
}
auto builder = HloComputation::Builder(TestName());
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
auto call1 = builder.AddInstruction(
HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
auto call2 = builder.AddInstruction(
@ -1565,6 +1564,104 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
static bool IsPostOrderTraversal(
const std::vector<const HloInstruction*>& sequence) {
tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
return seen_so_far.count(instruction) == 0;
};
for (auto instruction : sequence) {
if (std::any_of(instruction->operands().begin(),
instruction->operands().end(), has_not_been_seen_yet) ||
std::any_of(instruction->control_predecessors().begin(),
instruction->control_predecessors().end(),
has_not_been_seen_yet)) {
return false; // Not a post order.
}
if (!seen_so_far.insert(instruction).second) {
return false; // Not a "traversal".
}
}
return true;
}
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto input1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, data_shape_, "input1"));
auto weights1 = builder.AddInstruction(
HloInstruction::CreateParameter(3, data_shape_, "weights1"));
auto output1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto cond =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output0}));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({input1, weights1, output1}));
auto while0 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
while0->shape(), HloOpcode::kAdd, while0, while1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
{
FlattenCallGraph flatten;
TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get()));
EXPECT_TRUE(result);
}
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo sequence for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
std::vector<const HloInstruction*> sequence_for_buffer_assigment = {
input1, weights1, one, output1, tuple1, while1, input0,
weights0, zero, output0, tuple0, while0, root_add};
// If this ASSERT_TRUE fails, we constructed a bogus sequence above
// and this test itself is buggy.
ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment));
sequence[module->entry_computation()] =
std::move(sequence_for_buffer_assigment);
auto assignment = BufferAssigner::Run(module.get(),
MakeUnique<SequentialHloOrdering>(
module.get(), sequence),
ByteSizeOf, 1)
.ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
// Test buffer assignment for while nodes with multiple uses.
// TODO(b/37245345): Fix buffer assignment for this case.
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
@ -1577,7 +1674,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));

View File

@ -122,7 +122,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
b.instruction(), b.index(),
points_to_analysis())) {
&points_to_analysis())) {
return false;
}
}

View File

@ -397,13 +397,11 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
auto inner_tuple0 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()});
auto inner_tuple1 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
auto inner_tuple0 = Literal::MakeTuple(
{Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
@ -450,7 +448,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@ -462,7 +460,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element1_shape, tuple_param0, 1));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
@ -513,7 +511,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@ -585,7 +583,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
Literal::CreateR1<float>({2.f, 2.f, 2.f})));
HloInstruction* slice = nullptr;
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
@ -596,7 +594,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@ -715,7 +713,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
Literal::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
@ -724,7 +722,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));

View File

@ -133,6 +133,37 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
return nodes_[it->second];
}
bool CallGraph::DominatesHelper(
const HloComputation* a, const HloComputation* b,
tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
if (a == b || ContainsKey(*visited, b)) {
// The call graph is guaranteed to be acyclic so any previously visited node
// we encounter was already determined to be dominated.
return true;
}
const CallGraphNode& b_node = GetNode(b);
if (b_node.callers().empty()) {
// We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
return false;
}
// Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
visited->insert(b);
for (const HloComputation* b_caller : b_node.callers()) {
if (!DominatesHelper(a, b_caller, visited)) {
return false;
}
}
return true;
}
bool CallGraph::Dominates(const HloComputation* a,
const HloComputation* b) const {
tensorflow::gtl::FlatSet<const HloComputation*> visited;
return DominatesHelper(a, b, &visited);
}
namespace {
// Returns the call context of a computation which is called from contexts 'a'

View File

@ -189,6 +189,20 @@ class CallGraph {
Status VisitNodes(const VisitorFunction& visitor_func,
bool visit_unreachable_nodes = true) const;
// Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
// dominates computation 'b' iff all callgraph paths in the caller-to-callee
// direction from a root computation to 'b' pass through computation
// 'a'. Trivially, a computation dominates itself.
bool Dominates(const HloComputation* a, const HloComputation* b) const;
// Returns whether 'instruction' is contained in 'computation' either directly
// ('instruction->parent' is 'computation') or indirectly ('computation'
// dominates 'instruction->parent' in the call graph).
bool InstructionIsNestedIn(const HloInstruction* instruction,
const HloComputation* computation) const {
return Dominates(computation, instruction->parent());
}
string ToString() const;
private:
@ -205,6 +219,13 @@ class CallGraph {
const VisitorFunction& visitor_func, const CallGraphNode& node,
tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
// Recursive helper for computing whether 'a' dominates 'b' in the call
// graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
// and 'visited' is the set of computations which have been visited.
bool DominatesHelper(
const HloComputation* a, const HloComputation* b,
tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;
// The HLO module represented by this call graph.
const HloModule* module_ = nullptr;

Some files were not shown because too many files have changed in this diff Show More