Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Avijit 2018-08-15 17:00:22 -07:00
commit bc6be507c7
437 changed files with 14182 additions and 5232 deletions

View File

@ -3,7 +3,7 @@
## Major Features And Improvements
* The `tf.lite` runtime now supports `complex64`.
* Initial Bigtable integration for `tf.data`.
* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`.
* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation.
* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`.
* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018.

View File

@ -839,15 +839,16 @@ def set_tf_cuda_version(environ_cp):
cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_path = 'lib/x64/cudart.lib'
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version
cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version)
for x in ['lib64', 'lib/x86_64-linux-gnu']]
elif is_macos():
cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path)
if os.path.exists(cuda_toolkit_path_full):
break
cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths]
if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
@ -1398,10 +1399,6 @@ def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')
def set_build_strip_flag():
write_to_bazelrc('build --strip=always')
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
# The non-monolithic build is not supported yet
@ -1560,7 +1557,6 @@ def main():
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_build_strip_flag()
if is_windows():
set_windows_build_flags(environ_cp)

View File

@ -124,12 +124,6 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "windows_msvc",
values = {"cpu": "x64_windows_msvc"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_tensorflow_py_deps",
define_values = {"no_tensorflow_py_deps": "true"},
@ -439,14 +433,14 @@ package_group(
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
)
filegroup(
name = "intel_binary_blob",
data = if_mkl(
data = if_mkl_ml(
[
"//third_party/mkl:intel_binary_blob",
"//third_party/intel_mkl_ml",
],
),
)
@ -497,7 +491,6 @@ tf_cc_shared_object(
linkopts = select({
"//tensorflow:darwin": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"$(location //tensorflow:tf_framework_version_script.lds)",
@ -539,7 +532,6 @@ tf_cc_shared_object(
"-Wl,-install_name,@rpath/libtensorflow.so",
],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
@ -564,7 +556,6 @@ tf_cc_shared_object(
"$(location //tensorflow:tf_exported_symbols.lds)",
],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file

View File

@ -628,7 +628,6 @@ tf_cc_binary(
copts = tf_copts(),
linkopts = select({
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//tensorflow:darwin": [
"-lm",
"-lpthread",

View File

@ -314,12 +314,16 @@ cc_library(
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
],
deps = [
":common",
@ -354,6 +358,7 @@ cc_library(
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
@ -418,10 +423,12 @@ tf_cc_test(
srcs = [
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
deps = [
":common",
":compilation_passes",
":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@ -23,15 +24,18 @@ namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
PartiallyDeclusterPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
BuildXlaLaunchOpsPass);
} // namespace tensorflow

View File

@ -210,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie()));
VLOG(1) << "Done";
}

View File

@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def,
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Loop marking: " << while_node.type_string();
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'cond' attribute on While node.";
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
VLOG(2) << "Can't compile loop condition: " << cond_func;
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'body' attribute on While node.";
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
VLOG(2) << "Can't compile loop body: " << body_func;
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
}
VLOG(2) << "Loop is compilable.";
return true;
}
@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node,
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Function marking: " << call_def.op();
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Function depth limit exceeded";
VLOG(2) << "Rejecting " << call_def.op()
<< ": function depth limit exceeded.";
return false;
}
@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def,
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
if (!status.ok()) {
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
VLOG(2) << "Rejecting " << call_def.op()
<< ": could not instantiate: " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def,
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
VLOG(2) << "Rejecting " << call_def.op()
<< ": can't compile noinline function.";
return false;
}
@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
lib_runtime)) {
VLOG(2) << "Function marking failed: unsupported op " << node->name()
<< ": " << node->def().ShortDebugString();
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
<< node->name() << ": " << node->def().ShortDebugString();
return false;
}
}
VLOG(2) << "Function is compilable: " << call_def.op();
return true;
}
// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
DT_RESOURCE) != node.input_types().end() ||
std::find(node.output_types().begin(), node.output_types().end(),
DT_RESOURCE) != node.output_types().end();
}
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
@ -357,24 +351,27 @@ Status FindCompilationCandidates(
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
if (fuel >= std::numeric_limits<int64>::max() / 2) {
// The assumption is that if fuel started out as INT64_MAX, it will forever
// stay greater than INT64_MAX / 2.
VLOG(2) << "Starting fuel: infinity";
} else {
VLOG(2) << "Starting fuel: " << fuel;
}
for (Node* node : sorted_nodes) {
VLOG(2) << "Fuel: " << fuel;
if (fuel <= 0) {
VLOG(2)
VLOG(1)
<< "Hit fuel limit; not marking any remaining ops as clusterable.";
break;
}
VLOG(2) << "FindCompilationCandidates(): Processing "
<< node->DebugString();
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
VLOG(2) << "Compilation rejected node: not compilable " << node->name()
<< ": " << node->type_string();
// is_compilable_fn has already logged the reason if it returned false.
continue;
}
@ -384,14 +381,14 @@ Status FindCompilationCandidates(
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
<< ": " << node->type_string();
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
<< node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
HasResourceInputOrOutput(*node)) {
VLOG(2) << "Compilation rejected node: resource input/output "
<< node->name() << ": " << node->type_string();
VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
<< node->type_string();
continue;
}
if (node->type_string() == "While" &&
@ -401,15 +398,11 @@ Status FindCompilationCandidates(
// _Arg nodes in a top-level function represent feeds.
// Do not compile them.
if (node->type_string() == "_Arg") {
VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
<< node->DebugString();
continue;
}
// _Retval nodes in a top-level function represent fetches.
// Do not compile them.
if (node->type_string() == "_Retval") {
VLOG(2) << "Compilation rejected node: return value " << node->name()
<< ": " << node->type_string();
continue;
}
candidates->insert(node);
@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run(
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
return false;
}
@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run(
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
if (status.ok()) {
if (!compile) {
VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
<< kXlaCompileAttr << ") is false.";
}
return compile;
}
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
if (status.ok()) {
if (!compile) {
VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
<< kXlaCompileAttr << ") on callee is false.";
}
return compile;
}
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
return false;
}
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
VLOG(2) << "Rejecting " << node->name()
<< ": not fusable op but fusion_only enabled.";
return false;
}
@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run(
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
return (ignore_registration || registration->enable_jit_by_default) &&
global_jit_level > 0;
bool should_compile =
(ignore_registration || registration->enable_jit_by_default) &&
global_jit_level > 0;
if (!should_compile) {
if (global_jit_level <= 0) {
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
} else {
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
}
}
return should_compile;
};
return RunImpl(options, is_compilable);
}

View File

@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
// Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
// unconditionally, call RunImpl() directly.
// is_compilable_fn, if set, is a predicate that must be true for a node to
// be compiled.
private:
Status RunImpl(const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn = {});
friend class MarkForCompilationPassTestHelper;
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -39,27 +39,6 @@ namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
Status MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
n->set_assigned_device_name(kCpuDevice);
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunImpl(opt_options);
}
Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
return MarkForCompilation(graph, &flib_def);
}
std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size());
@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_FALSE(clusters.empty());
}
@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
}
@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
TF_ASSERT_OK(
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// Nothing should be compiled. In particular, 'd' and 'c' must not be
@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) {
ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
TF_EXPECT_OK(root.ToGraph(graph.get()));
Status status = MarkForCompilation(&graph);
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.ToString(),
"Edge from c to a would create a cycle.\n"
@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
auto r = ops::_Retval(root.WithOpName("R"), c, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
auto r = ops::_Retval(root.WithOpName("R"), b, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), 0.5f);
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_EQ(1, GetClusters(*graph).size());
}
@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), string("string"));
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_TRUE(GetClusters(*graph).empty());
}
}
@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);

View File

@ -0,0 +1,40 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
n->set_assigned_device_name(kCpuDevice);
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunImpl(opt_options);
}
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
return MarkForCompilation(graph, &flib_def);
}
} // namespace tensorflow

View File

@ -0,0 +1,35 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
namespace tensorflow {
class MarkForCompilationPassTestHelper {
public:
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
// `graph` to the CPU device. To make testing easier, ignores device
// registration, _XlaCompile attributes, input deadness and global jit level.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
// Like `MarkForCompilation` but creates `flib_def` from the op registry.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_

View File

@ -0,0 +1,177 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
namespace tensorflow {
namespace {
Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
gtl::ArraySlice<Node*> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
// avoid the device-host copy we'd otherwise need.
MemoryTypeVector input_mtypes, output_mtypes;
for (Node* n : post_order) {
gtl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
if (!from_cluster) {
continue;
}
// We assume the only XLA-auto-clusterable operations with side effects are
// resource variable updates. We can't execute these twice.
if (HasResourceInputOrOutput(*n)) {
continue;
}
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(n->assigned_device_name(), &device_type));
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
n->def(), &input_mtypes,
&output_mtypes));
for (const Edge* e : n->out_edges()) {
Node* dst = e->dst();
if (e->IsControlEdge()) {
continue;
}
bool edge_incurs_extra_device_to_host_copy;
if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
// If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
// keep the node clustered -- XLA will also produce the output in device
// memory and we will get some benefit from clustering.
edge_incurs_extra_device_to_host_copy = false;
} else {
MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
DeviceType dst_device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type));
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
dst->def(), &dst_input_mtypes,
&dst_output_mtypes));
edge_incurs_extra_device_to_host_copy =
dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
}
if (!edge_incurs_extra_device_to_host_copy) {
continue;
}
// Check if `dst` is in a different cluster, unclustered, or about to be
// partially declustered (here we rely on the post-order traversal order).
// If yes, decluster `n` to avoid the device-to-host memcpy.
gtl::optional<StringPiece> dst_cluster =
result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst);
if (from_cluster != dst_cluster) {
CHECK(result->insert(n).second);
break;
}
}
}
return Status::OK();
}
Status PartiallyDeclusterNode(Graph* graph, Node* n) {
StringPiece cluster_name = *GetXlaClusterForNode(*n);
gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
for (const Edge* out_edge : n->out_edges()) {
if (out_edge->IsControlEdge()) {
continue;
}
Node* dst = out_edge->dst();
gtl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
if (dst_cluster_name != cluster_name) {
out_edges_to_clone.push_back(out_edge);
}
}
CHECK(!out_edges_to_clone.empty()) << n->DebugString();
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(n->name(), "/declustered"));
RemoveFromXlaCluster(&ndef);
Status s;
Node* cloned_node = graph->AddNode(ndef, &s);
cloned_node->set_assigned_device_name(n->assigned_device_name());
TF_RETURN_IF_ERROR(s);
for (const Edge* in_edge : n->in_edges()) {
graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
in_edge->dst_input());
}
for (const Edge* out_edge_to_clone : out_edges_to_clone) {
graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
graph->RemoveEdge(out_edge_to_clone);
}
return Status::OK();
}
} // namespace
Status PartiallyDeclusterPass::Run(
const GraphOptimizationPassOptions& options) {
// NB! In this pass we assume the only XLA-auto-clusterable operations that
// may have side effects are resource variable operations so we don't cluster
// those. The pass will have to be updated if this assumption becomes
// invalid.
Graph* graph = options.graph->get();
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
gtl::FlatSet<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(FindNodesToDecluster(
**options.graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
if (nodes_to_partially_decluster.count(n)) {
VLOG(3) << n->DebugString();
}
}
}
for (Node* n : post_order) {
if (nodes_to_partially_decluster.count(n)) {
TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
}
}
nodes_to_partially_decluster.clear();
TF_RETURN_IF_ERROR(FindNodesToDecluster(
**options.graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,58 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// Clones nodes from within a cluster to outside the cluster if profitable.
//
// Today this only clones to avoid device-to-host copies, but in the future we
// may consider other reasons to clone. For instance, we convert this:
//
// .....
// |
// v
// A_Clustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// to:
//
// .....
// | |
// | +-------------+
// | |
// v v
// A_Clustered A_Unclustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA
// cluster.
class PartiallyDeclusterPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_

View File

@ -0,0 +1,284 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
REGISTER_OP("FakeNullary").Output("out: float");
REGISTER_OP("FakeBinary")
.Input("host_in: float")
.Input("device_in: float")
.Output("host_out: float")
.Output("device_out: float");
REGISTER_OP("FakeResourceVar").Output("out: resource");
REGISTER_OP("FakeResourceUpdate")
.Input("in: resource")
.Output("out: resource")
.Output("something_else: float");
class FakeBinaryOp : public OpKernel {
public:
explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
class FakeResourceVarUpdateOp : public OpKernel {
public:
explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
REGISTER_KERNEL_BUILDER(Name("FakeBinary")
.Device(DEVICE_CPU)
.HostMemory("host_in")
.HostMemory("host_out"),
FakeBinaryOp);
REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
.Device(DEVICE_CPU)
.HostMemory("something_else"),
FakeResourceVarUpdateOp);
Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get());
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
n->set_assigned_device_name(kCpuDevice);
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
PartiallyDeclusterPass pass;
return pass.Run(opt_options);
}
const Node* FindNodeByName(const Graph& graph, const string& name) {
for (const Node* node : graph.nodes()) {
if (node->name() == name) {
return node;
}
}
return nullptr;
}
bool GetInputsForNode(const Graph& graph, const string& node_name,
std::vector<Node*>* inputs) {
const Node* node = FindNodeByName(graph, node_name);
if (node == nullptr) {
return false;
}
for (const Edge* e : node->in_edges()) {
inputs->push_back(e->src());
}
std::sort(inputs->begin(), inputs->end(), NodeComparatorName());
return true;
}
TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* clustered_producer =
ops::BinaryOp("FakeBinary", input, input,
builder.opts().WithName("ClusteredProducer"));
ops::BinaryOp("FakeBinary", clustered_producer, input,
builder.opts().WithName("UnclusteredConsumer"));
Node* clustered_consumer =
ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
builder.opts().WithName("ClusteredConsumer"));
clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
std::vector<Node*> unclustered_consumer_inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
&unclustered_consumer_inputs));
ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
"ClusteredProducer/declustered");
EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
std::vector<Node*> clustered_consumer_inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer",
&clustered_consumer_inputs));
ASSERT_EQ(clustered_consumer_inputs.size(), 2);
EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer");
EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input");
}
TEST(PartiallyDeclusterPassTest, DifferentClusters) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* clustered_producer =
ops::BinaryOp("FakeBinary", input, input,
builder.opts().WithName("ClusteredProducer"));
Node* consumer_in_different_cluster =
ops::BinaryOp("FakeBinary", clustered_producer, input,
builder.opts().WithName("ConsumerInDifferentCluster"));
Node* clustered_consumer =
ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
builder.opts().WithName("ClusteredConsumer"));
clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
std::vector<Node*> inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
ASSERT_EQ(inputs.size(), 2);
EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered");
EXPECT_EQ(inputs[1]->name(), "Input");
}
TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* clustered_producer =
ops::BinaryOp("FakeBinary", input, input,
builder.opts().WithName("ClusteredProducer"));
// The first input is hostmem and the second input is devicemem.
Node* consumer_in_different_cluster =
ops::BinaryOp("FakeBinary", input, clustered_producer,
builder.opts().WithName("ConsumerInDifferentCluster"));
Node* clustered_consumer =
ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
builder.opts().WithName("ClusteredConsumer"));
clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
std::vector<Node*> inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
ASSERT_EQ(inputs.size(), 2);
EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
EXPECT_EQ(inputs[1]->name(), "Input");
}
TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* resource_var = ops::SourceOp("FakeResourceVar",
builder.opts().WithName("ResourceVar"));
Node* clustered_producer =
ops::UnaryOp("FakeResourceUpdate", resource_var,
builder.opts().WithName("ClusteredProducer"));
Node* consumer_in_different_cluster =
ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
builder.opts().WithName("ConsumerInDifferentCluster"));
Node* clustered_consumer =
ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
builder.opts().WithName("ClusteredConsumer"));
clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
std::vector<Node*> inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
ASSERT_EQ(inputs.size(), 2);
EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
EXPECT_EQ(inputs[1]->name(), "Input");
}
TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* clustered_producer_0 =
ops::BinaryOp("FakeBinary", input, input,
builder.opts().WithName("ClusteredProducer0"));
Node* clustered_producer_1 =
ops::BinaryOp("FakeBinary", clustered_producer_0, input,
builder.opts().WithName("ClusteredProducer1"));
ops::BinaryOp("FakeBinary", clustered_producer_1, input,
builder.opts().WithName("UnclusteredConsumer"));
Node* clustered_consumer =
ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input,
builder.opts().WithName("ClusteredConsumer"));
clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
std::vector<Node*> unclustered_consumer_inputs, declustered_producer_1_inputs;
ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
&unclustered_consumer_inputs));
ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
"ClusteredProducer1/declustered");
EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered",
&declustered_producer_1_inputs));
ASSERT_EQ(declustered_producer_1_inputs.size(), 2);
EXPECT_EQ(declustered_producer_1_inputs[0]->name(),
"ClusteredProducer0/declustered");
EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
} // namespace
} // namespace tensorflow

View File

@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
return Status::OK();
}
gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
if (attr_value == nullptr) {
return gtl::nullopt;
}
Status s = AttrValueHasType(*attr_value, "string");
if (!s.ok()) {
return gtl::nullopt;
}
return attr_value->s();
}
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
DT_RESOURCE) != node.input_types().end() ||
std::find(node.output_types().begin(), node.output_types().end(),
DT_RESOURCE) != node.output_types().end();
}
void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace tensorflow {
@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node);
// the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
// otherwise returns nullopt.
gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(NodeDef* node_def);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -78,7 +78,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie()));
return Status::OK();
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@ -216,6 +217,8 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
/*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
xla::StreamPool::Ptr* stream,
std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
xla::StreamPool::Ptr ptr;
TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
*stream = std::shared_ptr<se::Stream>(std::move(ptr));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
se::Stream* host_to_device_stream = stream_.get();
se::Stream* device_to_host_stream = stream_.get();
std::shared_ptr<se::Stream> host_to_device_stream = stream_;
std::shared_ptr<se::Stream> device_to_host_stream = stream_;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
host_to_device_stream = host_to_device_stream_.get();
device_to_host_stream = device_to_host_stream_.get();
host_to_device_stream = host_to_device_stream_;
device_to_host_stream = device_to_host_stream_;
}
if (!need_new_device_context) {
@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
// The XlaDeviceContext keeps a reference count to the streams, and the
// XlaDeviceContext remains live for the duration of a Executor run. This
// ensures that the streams remain live for the duration of a run, even if
// an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
stream_.get(), host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
stream_, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
Status XlaDevice::Sync() {
VLOG(1) << "XlaDevice::Sync";
std::shared_ptr<se::Stream> stream;
{
mutex_lock lock(mu_);
stream = stream_;
}
if (!stream) return Status::OK();
if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
return errors::Internal("XlaDevice::Sync() failed.");
}
VLOG(1) << "XlaDevice::Sync completed";
return Status::OK();
}
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {

View File

@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override { return Status::OK(); }
Status Sync() override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
xla::StreamPool::Ptr* stream,
std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
#include <memory>
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(compute_stream),
host_to_device_stream_(host_to_device_stream),
device_to_host_stream_(device_to_host_stream),
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool)
: stream_(std::move(compute_stream)),
host_to_device_stream_(std::move(host_to_device_stream)),
device_to_host_stream_(std::move(device_to_host_stream)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
shape_representation_fn_(std::move(shape_representation_fn)),
thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_);
host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
host_to_device_stream_, *literal, shaped_buffer));
host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
se::Event event(stream_->parent());
TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(event.get());
xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_, shaped_buffer, literal,
device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
host_to_device_stream_->ThenDoHostCallback(
[done]() { done(Status::OK()); });
host_to_device_stream_->ThenDoHostCallback([this, done]() {
// We must not call the done closure directly from DoHostCallback
// to avoid a deadlock. If done() is the callback that ends an
// Executor's run, the Executor may call XlaDevice::Sync() inside the
// callback. This deadlocks, because XlaDevice::Sync() waits for all
// stream activity to complete.
thread_pool_->Schedule([done]() { done(Status::OK()); });
});
return;
}
} else {
@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
host_to_device_stream_, block_status.error_message().c_str());
host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
device_to_host_stream_->ThenWaitFor(event);
xla_tensor->SetDefinedOn(device_to_host_stream_);
xla_tensor->SetDefinedOn(device_to_host_stream_.get());
}
Status status;
@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
"Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
device_to_device_stream->ThenWaitFor(stream_);
device_to_device_stream->ThenWaitFor(stream_.get());
}
}
if (se::Event* event =
xla_src->GetDefinitionEvent(device_to_device_stream)) {
xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
device_to_device_stream->ThenWaitFor(event);
xla_src->SetDefinedOn(device_to_device_stream);
xla_src->SetDefinedOn(device_to_device_stream.get());
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
se::Event event(stream_->parent());
CHECK(event.Init());
device_to_device_stream->ThenRecordEvent(&event);
xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize";
device_to_device_stream->ThenRecordEvent(event.get());
xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
stream_->ThenDoHostCallback([this, done]() {
// We must not call the done closure directly from DoHostCallback to avoid
// a deadlock. If done() is the callback that ends an Executor's run, the
// Executor may call XlaDevice::Sync() inside the callback. This
// deadlocks, because XlaDevice::Sync() waits for all stream activity to
// complete.
thread_pool_->Schedule([done]() { done(Status::OK()); });
});
}
}
XlaDeviceContext::XlaDeviceContext(
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(compute_stream, host_to_device_stream, device_to_host_stream,
client, transfer_as_literal,
std::move(shape_representation_fn)) {}
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool)
: manager_(std::move(compute_stream), std::move(host_to_device_stream),
std::move(device_to_host_stream), client, transfer_as_literal,
std::move(shape_representation_fn), thread_pool) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,

View File

@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream,
xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@ -61,7 +63,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
se::Stream* stream() const { return stream_; }
se::Stream* stream() const { return stream_.get(); }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@ -73,13 +75,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
se::Stream* stream_;
std::shared_ptr<se::Stream> stream_;
// The stream to use for transferring data from host to device. Can be
// idential to stream_, but must not be nullptr.
se::Stream* host_to_device_stream_;
std::shared_ptr<se::Stream> host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// idential to stream_, but must not be nullptr.
se::Stream* device_to_host_stream_;
std::shared_ptr<se::Stream> device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@ -87,6 +89,9 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// Thread pool used for running closures
thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@ -95,10 +100,12 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream,
xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include <memory>
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
void XlaComputationLaunchContext::PopulateOutputs(
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
std::shared_ptr<se::Event> definition_event;
if (use_multiple_streams_) {
definition_event = std::make_shared<se::Event>(stream->parent());
if (!definition_event->Init()) {
return errors::Internal("Failed to initialize tensor definition event.");
}
stream->ThenRecordEvent(definition_event.get());
}
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
OP_REQUIRES_OK(
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
TF_RETURN_IF_ERROR(
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
OP_REQUIRES(ctx, device != nullptr,
errors::Internal("DeviceBase was not a Device."));
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, write.input_index),
&variable, [this, ctx, &write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, write.input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
errors::Internal("Mismatched type in variable write"));
if (variable->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
if (allocate_xla_tensors_) {
Tensor output_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
TF_RETURN_IF_ERROR(
ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
xla_tensor->SetDefinedOn(stream, definition_event);
}
*variable->tensor() = output_tensor;
} else {
@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
const std::map<int, OptionalTensor>& variables);
// Given the XLA output in `output`, populate all outputs of `ctx`.
void PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output);
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.

View File

@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
if (!definition_event_.has_value()) {
if (!definition_event_) {
return nullptr;
}
@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
return nullptr;
}
return &*definition_event_;
return definition_event_.get();
}
void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
void XlaTensor::SetDefinedOn(se::Stream* stream,
std::shared_ptr<se::Event> event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#include <memory>
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@ -94,7 +96,7 @@ class XlaTensor {
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
void SetDefinedOn(se::Stream* stream, se::Event event);
void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
@ -116,7 +118,7 @@ class XlaTensor {
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
gtl::optional<se::Event> definition_event_;
std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);

View File

@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;
void SetDebugOptionsDefaults(DebugOptions* flags) {
flags->set_xla_enable_fast_math(true);
flags->set_xla_llvm_enable_alias_scope_metadata(true);
flags->set_xla_llvm_enable_noalias_metadata(true);
flags->set_xla_llvm_enable_invariant_load_metadata(true);
@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// the heuristics needed to decide when to run on multiple streams. See
// b/77879207.
flags->set_xla_gpu_disable_multi_streaming(true);
// TODO(jlebar): Disable fastmath once doing so is not a performance
// regression.
flags->set_xla_cpu_enable_fast_math(true);
flags->set_xla_gpu_enable_fast_math(true);
}
// Allocates flag_values and flag_objects; this function must not be called more
@ -150,10 +154,16 @@ void AllocateFlags() {
flag_values->mutable_xla_generate_hlo_text_to(),
"Dump all HLO modules as text into the provided directory path."),
tensorflow::Flag(
"xla_enable_fast_math",
bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
flag_values->xla_enable_fast_math(),
"Enable unsafe fast-math optimizations in the compiler; "
"xla_cpu_enable_fast_math",
bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
flag_values->xla_cpu_enable_fast_math(),
"Enable unsafe fast-math optimizations in the CPU compiler; "
"this may produce faster code at the expense of some accuracy."),
tensorflow::Flag(
"xla_gpu_enable_fast_math",
bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
flag_values->xla_cpu_enable_fast_math(),
"Enable unsafe fast-math optimizations in the GPU compiler; "
"this may produce faster code at the expense of some accuracy."),
tensorflow::Flag(
"xla_llvm_enable_alias_scope_metadata",

View File

@ -1234,6 +1234,20 @@ cc_library(
],
)
cc_library(
name = "scatter_expander",
srcs = ["scatter_expander.cc"],
hdrs = ["scatter_expander.h"],
deps = [
":hlo",
":hlo_creation_utils",
":hlo_pass",
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
],
)
tf_cc_test(
name = "batchnorm_expander_test",
size = "small",

View File

@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
reshape, HloInstruction::CreateReshape(reshape->shape(),
operand->mutable_operand(0)));
}
if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
*operand->mutable_shape() = reshape->shape();
return ReplaceInstruction(reshape, operand);
}
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
transpose->dimensions())));
}
if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
*operand->mutable_shape() = transpose->shape();
return ReplaceInstruction(transpose, operand);
}
if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
ReplaceWithBitcast(transpose);
return Status::OK();

View File

@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
}
// Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
HloComputation::Builder builder(TestName());
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction* one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction* rng0 = builder.AddInstruction(
HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
RandomDistribution::RNG_UNIFORM, {zero, one}));
HloInstruction* transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
Shape reshape_shape = builder
.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {4}), transpose))
->shape();
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
// Verify that that reshape(transpose(rng)) is replace by a single rng of the
// same shape as the reshape.
EXPECT_THAT(computation->root_instruction(), op::Rng());
EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
reshape_shape));
}
// Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
HloComputation::Builder builder(TestName());

View File

@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType(
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
// Map/reduce etc computations are always thread-local.

View File

@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
return CallContext::kParallel;

View File

@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
// We add copies for all the indices of the true and false computaiton roots,
// We add copies for all the indices of the true and false computation roots,
// in order to resolve interference. We later rely on the CopyRemover to drop
// the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
@ -648,7 +648,12 @@ class CopyRemover {
// We can only perform copy elision if the resulting merged values have
// totally ordered live ranges; otherwise the merged buffer would have
// live range interference.
if (IsHead(*dest)) {
if (src->next == dest) {
// In the process of eliding copies, its possible for a copy to have the
// same source and destination buffer. In this case, the copy can be
// safely removed.
VLOG(2) << copy->name() << " source and destination buffers are same.";
} else if (IsHead(*dest)) {
// The copy copies an arbitrary value in the source buffer (call it s_x)
// and defines d_0, the first value in the destination buffer. After
// merging, the values in the combined buffer must be strictly ordered

View File

@ -2007,5 +2007,46 @@ ENTRY TestComputation {
InsertCopies(module.get());
}
TEST_F(CopyInsertionTest, NestedWhiles) {
// Verify that only no unnecessary copies remain after copy insertion for
// trivial nested whiles (b/112472605).
const string& hlo_string = R"(
HloModule TestModule
cond.inner {
ROOT param.cond.inner = pred[] parameter(0)
}
body.inner {
param.body.inner = pred[] parameter(0)
ROOT neg = pred[] negate(param.body.inner)
}
cond.outer {
ROOT param.cond.outer = pred[] parameter(0)
}
body.outer {
param.cond.outer = pred[] parameter(0)
ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}
ENTRY TestComputation {
entry_param = pred[] parameter(0)
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
InsertCopies(module.get());
// There should only be a single copy inserted, and it's in the entry
// computation.
EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::While(op::Copy(op::Parameter())));
}
} // namespace
} // namespace xla

View File

@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"mkl_deps",
)
# Filegroup used to collect source files for dependency checking.
@ -86,6 +86,7 @@ cc_library(
":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
@ -497,10 +498,7 @@ cc_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
] + if_mkl([
"@mkl_dnn",
"//third_party/mkl:intel_binary_blob",
]),
] + mkl_deps(),
)
cc_library(
@ -554,10 +552,7 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
] + if_mkl([
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
]),
] + mkl_deps(),
)
cc_library(

View File

@ -88,6 +88,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@ -299,6 +300,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
pipeline.AddPass<ScatterExpander>();
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@ -356,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions(
llvm::TargetOptions target_options;
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/module_config.debug_options()
.xla_enable_fast_math(),
.xla_cpu_enable_fast_math(),
&target_options);
return target_options;
}
@ -523,7 +526,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
module->config().debug_options().xla_enable_fast_math(),
module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_hook, post_optimization_ir_hook);
llvm_module->setDataLayout(jit->data_layout());
@ -653,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// so we bail if the configs have conflicting flags. At the moment, the only
// flag that needs to be consistent is fast-math.
const bool fast_math_enabled =
modules[0]->config().debug_options().xla_enable_fast_math();
modules[0]->config().debug_options().xla_cpu_enable_fast_math();
for (const auto& module : modules) {
if (module->config().debug_options().xla_enable_fast_math() !=
if (module->config().debug_options().xla_cpu_enable_fast_math() !=
fast_math_enabled) {
return InvalidArgument(
"All HLO module configs must have the same value for "
@ -832,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
CompilerFunctor compiler_functor(
target_machine.get(), &disassembler, opt_level,
options::OptimizeForSizeRequested(module->config()),
module->config().debug_options().xla_enable_fast_math(),
module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
std::unique_ptr<llvm::MemoryBuffer> object_file =

View File

@ -249,24 +249,11 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
}
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
arguments));
TF_RETURN_IF_ERROR(ExecuteComputeFunction(
&run_options->run_options(), unowning_buffers, hlo_execution_profile));
return CreateResultShapedBuffer(run_options, &owning_buffers);
auto result,
ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
return std::move(result);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@ -277,6 +264,16 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
"Asynchronous execution on stream with hlo profiling is not yet "
"supported on CPU.");
}
return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
}
auto* host_stream = dynamic_cast<se::host::HostStream*>(
run_options->stream()->implementation());
@ -310,19 +307,20 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
ServiceExecutableRunOptions run_options;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
HloExecutionProfile* hlo_execution_profile;
void operator()() {
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
&run_options.run_options(), unowning_buffers,
/*hlo_execution_profile=*/nullptr));
&run_options.run_options(), unowning_buffers, hlo_execution_profile));
}
};
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<OwningDeviceMemory>>(
std::move(owning_buffers))});
std::move(owning_buffers)),
hlo_execution_profile});
return std::move(result);
}

View File

@ -85,6 +85,16 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
// This is for sharing the code between ExecuteOnStream and
// ExecuteAsyncOnStream.
//
// Notice that it's tricky to use correctly, as the profile object (when it
// exists) must out-live the task.
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile);
// Creates an array suitable for passing as the "temps" argument to the JIT
// compiled function pointer.
//

View File

@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
<< config.GetCacheKey();
const bool enable_fast_math =
hlo_module_config_.debug_options().xla_enable_fast_math();
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
const bool enable_fast_math =
hlo_module_config_.debug_options().xla_enable_fast_math();
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);

View File

@ -99,7 +99,7 @@ IrEmitter::IrEmitter(
target_machine_features_(*target_machine_features) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_enable_fast_math()));
.xla_cpu_enable_fast_math()));
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
@ -158,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
// Create and initialize new IrFunction.
compute_function_.reset(
new IrFunction(function_name, linkage,
options::OptimizeForSizeRequested(hlo_module_config_),
hlo_module_config_.debug_options().xla_enable_fast_math(),
module_, &b_, num_dynamic_loop_bounds_));
compute_function_.reset(new IrFunction(
function_name, linkage,
options::OptimizeForSizeRequested(hlo_module_config_),
hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_,
&b_, num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
@ -577,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*reduce_window,
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32}));
/*supported_types=*/{F32, BF16, S32, F16}));
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(reduce_window->window())) {

View File

@ -21,8 +21,33 @@ limitations under the License.
namespace xla {
namespace {
// Pass which strips control dependencies from all instructions in the module.
class ControlDepRemover : public HloPassInterface {
public:
ControlDepRemover() = default;
tensorflow::StringPiece name() const override {
return "control-dep-remover";
}
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
changed = changed || !instruction->control_predecessors().empty();
TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
}
}
return changed;
}
};
} // namespace
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();
pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();

View File

@ -1,6 +1,7 @@
# Description:
# GPU-specific components in XLA service implementation.
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
licenses(["notice"]) # Apache 2.0
@ -365,6 +366,7 @@ cc_library(
":gpu_executable",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@ -652,6 +654,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",
@ -852,3 +855,35 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
cc_library(
name = "buffer_comparator",
srcs = ["buffer_comparator.cc"],
hdrs = ["buffer_comparator.h"],
deps = [
":gpu_executable",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
xla_test(
name = "buffer_comparator_test",
srcs = ["buffer_comparator_test.cc"],
backends = [
"cpu",
"gpu",
],
deps = [
":buffer_comparator",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,205 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include <cmath>
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
static constexpr float kTolerance = 0.1f;
static string GetCompHloText(size_t num_elements) {
// Implements the textual format of the comparison routine, as it's more
// readable.
static constexpr char kF16CompHloText[] = R"(
HloModule CompareF16
MaxF32 {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %max = f32[] maximum(%lhs, %rhs)
}
Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] {
%min_constant = f32[] constant(-65505)
%max_constant = f32[] constant(65505)
%large_constant = f32[] constant(1048576)
%min_values = f32[SIZE] broadcast(%min_constant), dimensions={}
%max_values = f32[SIZE] broadcast(%max_constant), dimensions={}
%large_values = f32[SIZE] broadcast(%large_constant), dimensions={}
%a = f16[SIZE] parameter(0)
%converted = f32[SIZE] convert(%a)
%clamped = f32[SIZE] clamp(%min_values, %converted, %max_values)
// Since the clamp() above already took care of infs, only NaNs will cause
// is-finite() to return false.
%is_finite = pred[SIZE] is-finite(%clamped)
ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values)
}
ENTRY MaxDifference {
%one_constant = f32[] constant(1.0)
%zero_constant = f32[] constant(0.0)
%ones = f32[SIZE] broadcast(%one_constant), dimensions={}
%lhs = f16[SIZE] parameter(0)
%rhs = f16[SIZE] parameter(1)
%lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize
%rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize
%sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical)
%sub_abs = f32[SIZE] abs(%sub)
%lhs_abs = f32[SIZE] abs(%lhs_canonical)
%rhs_abs = f32[SIZE] abs(%rhs_canonical)
%max = f32[SIZE] maximum(%lhs_abs, %rhs_abs)
%denominator = f32[SIZE] add(%max, %ones)
%error = f32[SIZE] divide(%sub_abs, %denominator)
ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
})";
auto size_string = std::to_string(num_elements);
return tensorflow::str_util::StringReplace(
kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
}
StatusOr<F16BufferComparator> F16BufferComparator::Create(
se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
DeviceMemoryAllocator* allocator, se::Stream* stream) {
auto stream_exec = stream->parent();
int64 num_elements = ref_buffer.ElementCount();
// One may consider using hlo_runner to do all the compilation and execution.
// However, as of the time hlo_runner doesn't support injection for Compiler*,
// Stream*, or even the allocator. We may revisit this in the future if it
// proves to be a maintenance burden.
TF_ASSIGN_OR_RETURN(
auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> {
HloModuleConfig config;
DebugOptions debug_options;
debug_options.set_xla_backend_optimization_level(2);
config.set_debug_options(debug_options);
TF_ASSIGN_OR_RETURN(
auto module, ParseHloString(GetCompHloText(num_elements), config));
TF_ASSIGN_OR_RETURN(
module,
compiler->RunHloPasses(std::move(module), stream_exec, nullptr));
return compiler->RunBackend(std::move(module), stream_exec, nullptr);
}()));
TF_ASSIGN_OR_RETURN(
auto shaped_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
auto device_ordinal = stream_exec->device_ordinal();
TF_ASSIGN_OR_RETURN(
auto owning_buffer,
allocator->Allocate(device_ordinal, ref_buffer.size()));
se::DeviceMemory<Eigen::half> buffer(
owning_buffer.AsDeviceMemoryBase());
stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size());
Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal);
ret.set_buffer(std::move(owning_buffer), {});
return std::move(ret);
}()));
return F16BufferComparator(stream, allocator, std::move(exec),
std::move(shaped_buffer));
}
StatusOr<bool> F16BufferComparator::CompareEqualImpl(
se::DeviceMemory<Eigen::half> test_buffer) {
if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
return InternalError("Mismatched buffer size: %lld vs %lld",
ref_buffer_.root_buffer().size(), test_buffer.size());
}
int64 num_elements = test_buffer.ElementCount();
TF_ASSIGN_OR_RETURN(
auto result_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
auto stream_exec = stream_->parent();
Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
auto device_ordinal = stream_exec->device_ordinal();
ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(),
device_ordinal);
shaped_test_buffer.set_buffer(test_buffer, {});
ExecutableRunOptions run_options;
run_options.set_device_ordinal(stream_exec->device_ordinal());
run_options.set_stream(stream_);
run_options.set_allocator(allocator_);
ServiceExecutableRunOptions service_run_options(run_options);
return exec_->ExecuteOnStream(
&service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr);
}()));
float result;
CHECK(result_buffer.root_buffer().size() == sizeof(result));
stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result));
TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
return result < kTolerance;
}
StatusOr<bool> F16BufferComparator::CompareEqual(
se::DeviceMemory<Eigen::half> test_buffer) {
TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer));
if (result) {
return true;
}
// Host side code that does the same thing, but report some of the
// differences as well.
int64 n = test_buffer.ElementCount();
std::vector<half> host_ref_buffer(n), host_test_buffer(n);
stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(),
ref_buffer_.root_buffer().size());
stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size());
TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
const auto canonicalize = [](float a) -> float {
constexpr float kBigNumer = 1048576.;
constexpr float kMaxFp16Value = 65504.;
if (std::isnan(a)) {
return kBigNumer;
}
if (std::isinf(a)) {
if (a < 0) {
return -(kMaxFp16Value + 1);
}
return kMaxFp16Value + 1;
}
return a;
};
int differences_seen = 0;
for (int64 i = 0; i < n && differences_seen < 10; i++) {
float original_ref = static_cast<float>(host_ref_buffer[i]);
float original_test = static_cast<float>(host_test_buffer[i]);
float ref = canonicalize(original_ref);
float test = canonicalize(original_test);
if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) <
kTolerance)) {
differences_seen++;
LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs "
<< original_test;
}
}
return false;
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,71 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
// A fp16 comparator that internally keeps a reference buffer, and compares it
// against other test buffers.
class F16BufferComparator {
public:
F16BufferComparator(const F16BufferComparator&) = delete;
F16BufferComparator(F16BufferComparator&&) = default;
// Creates a new comparator. It internally allocates a buffer initialized by
// ref_buffer.
static StatusOr<F16BufferComparator> Create(
se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
DeviceMemoryAllocator* allocator, se::Stream* stream);
// Returns true if the internally allocated buffer "compares equal" to
// test_buffer. The definition of "equal" is:
// * All NaNs equal.
// * All infs are treated as 65505 or -65505, so that this checker is tolerant
// to fp16 overflows.
// * With NaNs and infs taken care of, a and b compare equal iff:
// abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance
//
// See the implementation for the tolerance value.
StatusOr<bool> CompareEqual(se::DeviceMemory<Eigen::half> test_buffer);
private:
F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator,
std::unique_ptr<Executable> exec,
ScopedShapedBuffer ref_buffer)
: stream_(stream),
allocator_(allocator),
exec_(std::move(exec)),
ref_buffer_(std::move(ref_buffer)) {}
StatusOr<bool> CompareEqualImpl(se::DeviceMemory<Eigen::half> test_buffer);
se::Stream* stream_;
DeviceMemoryAllocator* allocator_;
std::unique_ptr<Executable> exec_;
ScopedShapedBuffer ref_buffer_;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_

View File

@ -0,0 +1,126 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include <limits>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace gpu {
namespace {
class BufferComparatorTest : public testing::Test {
protected:
BufferComparatorTest()
: backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()),
stream_exec_(backend_->default_stream_executor()),
allocator_(stream_exec_->platform(), {stream_exec_}),
compiler_(Compiler::GetForPlatform(stream_exec_->platform())
.ConsumeValueOrDie()) {}
// Take floats only for convenience. Still uses half internally.
bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
const std::vector<float>& rhs_float) {
std::vector<half> lhs(lhs_float.begin(), lhs_float.end());
std::vector<half> rhs(rhs_float.begin(), rhs_float.end());
se::Stream stream(stream_exec_);
stream.Init();
auto owning_lhs_buffer =
allocator_
.Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half))
.ConsumeValueOrDie();
auto owning_rhs_buffer =
allocator_
.Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half))
.ConsumeValueOrDie();
auto lhs_buffer =
se::DeviceMemory<Eigen::half>(owning_lhs_buffer.AsDeviceMemoryBase());
auto rhs_buffer =
se::DeviceMemory<Eigen::half>(owning_rhs_buffer.AsDeviceMemoryBase());
stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size());
stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size());
TF_CHECK_OK(stream.BlockHostUntilDone());
return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_,
&stream)
.ConsumeValueOrDie()
.CompareEqual(rhs_buffer)
.ConsumeValueOrDie();
}
std::unique_ptr<Backend> backend_;
se::StreamExecutor* stream_exec_;
StreamExecutorMemoryAllocator allocator_;
Compiler* compiler_;
};
TEST_F(BufferComparatorTest, TestNaNs) {
EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")}));
// NaN values with different bit patterns should compare equal.
EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")}));
EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.}));
}
TEST_F(BufferComparatorTest, TestInfs) {
const auto inf = std::numeric_limits<float>::infinity();
EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")}));
EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf}));
EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504}));
EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504}));
EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504}));
EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504}));
EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20}));
EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20}));
EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20}));
EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20}));
}
TEST_F(BufferComparatorTest, TestNumbers) {
EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1}));
EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1}));
EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1}));
EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10}));
EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9}));
}
TEST_F(BufferComparatorTest, TestMultiple) {
EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60},
{20.1, 30.1, 40.1, 50.1, 60.1}));
std::vector<float> lhs(200);
std::vector<float> rhs(200);
for (int i = 0; i < 200; i++) {
EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs))
<< "should be the same at index " << i;
lhs[i] = 3;
rhs[i] = 5;
EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs))
<< "should be the different at index " << i;
lhs[i] = 0;
rhs[i] = 0;
}
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -30,7 +30,6 @@ namespace {
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
@ -173,7 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
optional<std::tuple<int64, bool, int64>>
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
@ -206,45 +205,25 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// Allocate space for the input, filter, and output of the convolution. We
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
//
// We don't put any data in these buffers, because (in theory, anyway) the
// speed of a conv isn't affected by the data being convolved.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
StatusOr<DeviceMemoryBase> maybe_input_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(input_shape));
StatusOr<DeviceMemoryBase> maybe_filter_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(filter_shape));
StatusOr<DeviceMemoryBase> maybe_output_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(output_shape));
if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() ||
!maybe_output_buf.ok()) {
LOG(WARNING)
<< "Couldn't allocate space for input/filter/output of convolution "
<< instr->ToString() << ". Falling back to default algorithm.";
return nullopt;
}
DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie();
DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie();
DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie();
TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(input_shape)));
TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(filter_shape)));
TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(output_shape)));
// Although we don't have evidence this matters, zero out the buffers before
// autotuning. It's conceivable that using uninitialized memory as the inputs
// might affect performance if e.g. the inputs contain denormals, and this is
// easy enough.
if (!stream.ThenMemZero(&input_buf, input_buf.size())
.ThenMemZero(&filter_buf, filter_buf.size())
.ThenMemZero(&output_buf, output_buf.size())
.BlockHostUntilDone()
.ok()) {
LOG(WARNING)
<< "Couldn't zero out input/filter/output buffer for convolution "
<< instr->ToString() << ". Falling back to default algorithm.";
return nullopt;
}
TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size())
.ThenMemZero(&filter_buf, filter_buf.size())
.ThenMemZero(&output_buf, output_buf.size())
.BlockHostUntilDone());
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
input_shape, output_shape, dnums, stream_exec_);
@ -292,9 +271,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
best_result_bytes_used);
}
LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
<< " failed. Falling back to default algorithm.";
return nullopt;
return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm.",
instr->ToString().c_str());
}
StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
@ -305,12 +285,13 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
const auto& lhs_shape = instr->operand(0)->shape();
const auto& rhs_shape = instr->operand(1)->shape();
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
if (call_target == kCudnnConvForwardCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
instr->window(), instr->convolution_dimension_numbers(), instr);
alg_scratch_and_tc =
PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, instr->window(),
instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
@ -326,7 +307,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
<< instr->ToString();
}
if (!alg_scratch_and_tc.has_value()) {
if (!alg_scratch_and_tc.ok()) {
LOG(ERROR) << alg_scratch_and_tc.status();
return false;
}
@ -334,7 +316,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
bool tensor_ops_enabled;
int64 scratch_bytes;
std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
alg_scratch_and_tc.ConsumeValueOrDie();
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
<< NumBytesToString(scratch_bytes)

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* allocator)
: stream_exec_(stream_exec), allocator_(allocator) {}
DeviceMemoryAllocator* allocator,
Compiler* compiler)
: stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
tensorflow::StringPiece name() const override {
return "cudnn-convolution-algorithm-picker";
@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
Compiler* compiler_;
};
} // namespace gpu

View File

@ -210,11 +210,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
return make_sqrt();
}
if (hlo_module_config_.debug_options().xla_enable_fast_math() &&
IsFPLiteralWithValue(rhs, -.5)) {
if (IsFPLiteralWithValue(rhs, -.5)) {
VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
// LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
// rsqrt.approx instruction.
//
// TODO(jlebar): Does this happen with fastmath disabled? If not, should
// we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
@ -274,16 +276,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
// If we don't care much about precision, emit a fast approximation of
// tanh.
if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
llvm::Value* input = b_->CreateFPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
return b_->CreateFPCast(fast_tanh, value->getType());
}
return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
// Emit a fast approximation of tanh instead of calling __nv_tanh.
// __nv_tanh is particularly bad because it contains branches, thus
// preventing LLVM's load-store vectorizer from working its magic across a
// function which contains tanh calls.
//
// This routine isn't numerically precise, but it's good enough for ML.
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
llvm::Value* input = b_->CreateFPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
return b_->CreateFPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(

View File

@ -64,7 +64,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
hlo_module_config_(hlo_module_config) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config.debug_options()
.xla_enable_fast_math()));
.xla_gpu_enable_fast_math()));
}
Status IrEmitter::DefaultAction(HloInstruction* hlo) {

View File

@ -180,7 +180,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/hlo_module_config.debug_options()
.xla_enable_fast_math(),
.xla_gpu_enable_fast_math(),
&target_options);
// Enable FMA synthesis.

View File

@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@ -130,8 +131,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
}
// Runs optimization passes on the given HLO module.
//
// It takes a compiler pointer, as passes may compile and execute HLOs on the
// fly for cuDNN verification or other purposes.
Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) {
DeviceMemoryAllocator* device_allocator,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
@ -167,6 +172,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();
pipeline.AddPass<ScatterExpander>();
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
@ -245,8 +252,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// the gte(customcall, 0) would probably already be into a fusion node. We
// can't simplify across HloComputation boundaries, so in this case we
// wouldn't be able to simplify away the new_tuple bits.
pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
device_allocator);
pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(
stream_exec, device_allocator, compiler);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
@ -492,11 +499,15 @@ NVPTXCompiler::NVPTXCompiler()
StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) {
// We dump the post-optimization HLO in RunBackend so no need to dump it here.
VLOG(2) << "*** HLO Before Optimization";
XLA_VLOG_LINES(2, module->ToString());
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses");
tracing::ScopedActivity activity("HLO Transforms", module->name(),
/*is_expensive=*/true);
TF_RETURN_IF_ERROR(
OptimizeHloModule(module.get(), stream_exec, device_allocator));
OptimizeHloModule(module.get(), stream_exec, device_allocator, this));
return std::move(module);
}
@ -548,6 +559,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// include headers, so no need for us to print them ourselves.
XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
XLA_VLOG_LINES(2, buffer_assignment->ToString());
VLOG(2) << "*** HLO After Optimization";
XLA_VLOG_LINES(2, module->ToString());
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();

View File

@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
StatusOr<HloInstruction*> MakeMapHlo(
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation) {
CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
HloComputation* computation = operands.front()->parent();
std::vector<const Shape*> operand_shapes;
int64 max_operand_rank = 0;
for (const HloInstruction* operand : operands) {
CHECK_EQ(computation, operand->parent());
operand_shapes.push_back(&operand->shape());
max_operand_rank =
std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
}
std::vector<int64> map_dims(max_operand_rank);
std::iota(map_dims.begin(), map_dims.end(), 0);
TF_ASSIGN_OR_RETURN(
Shape map_shape,
ShapeInference::InferMapShape(
operand_shapes, map_computation->ComputeProgramShape(), map_dims));
return computation->AddInstruction(
HloInstruction::CreateMap(map_shape, operands, map_computation));
}
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
return MakeReshapeHlo(output_shape, operand);
}
StatusOr<HloInstruction*> InsertDegenerateDims(
HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
CHECK(c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
int64 output_shape_rank =
operand_shape.dimensions_size() + dims_to_insert.size();
for (auto dim_to_insert : dims_to_insert) {
CHECK_LT(dim_to_insert, output_shape_rank);
}
std::vector<int64> output_shape_dim_bounds;
output_shape_dim_bounds.reserve(output_shape_rank);
int64 operand_dims_idx = 0;
int64 dims_to_insert_idx = 0;
for (int64 i = 0; i < output_shape_rank; ++i) {
if (dims_to_insert_idx < dims_to_insert.size() &&
i == dims_to_insert[dims_to_insert_idx]) {
output_shape_dim_bounds.push_back(1);
++dims_to_insert_idx;
} else {
output_shape_dim_bounds.push_back(
operand_shape.dimensions(operand_dims_idx));
++operand_dims_idx;
}
}
Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
output_shape_dim_bounds);
return MakeReshapeHlo(output_shape, operand);
}
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
int64 zeros_to_prepend,
int64 zeros_to_append) {

View File

@ -102,6 +102,12 @@ StatusOr<HloInstruction*> MakeConcatHlo(
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
StatusOr<HloInstruction*> MakeMapHlo(
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation);
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
@ -144,6 +150,16 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert` into `operand`. The dimensions in
// `dims_to_insert` refer to the dimensions in the result, and hence should be
// less than the rank of the result. Also, `dims_to_insert` must be sorted.
//
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,

View File

@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() {
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
case '/':
return LexComment();
case '/': {
if (PeekCurrentChar() == '*') {
// This is the start of a /*...*/ delimited comment. Save the current
// location in case the comment is unterminated so the error message
// will point to the beginning of the comment.
const char* comment_start = current_ptr_;
current_ptr_++;
// Advance until '*/' is found.
while (true) {
int current = GetNextChar();
if (current == '*' && PeekCurrentChar() == '/') {
// End of comment.
current_ptr_++;
break;
}
if (current == kEOF) {
// Unterminated comment.
current_ptr_ = comment_start;
return TokKind::kError;
}
}
// Return no token for the comment. Keep lexing.
continue;
} else if (PeekCurrentChar() == '/') {
// This is the start of a '//' delimited comment. Throw away
// everything until end of line or file. The end-of-line character(s)
// are left unlexed in the buffer which is harmless because these are
// skipped later by the lexer. This approach enables support for
// different end-of-line encodings.
while (true) {
int current = PeekCurrentChar();
if (current == kEOF || current == '\n' || current == '\r') {
break;
}
current_ptr_++;
}
continue;
}
// A lone '/' is an error.
return TokKind::kError;
}
case '"':
return LexString();
}
@ -357,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
return StringPieceFromPointers(start, end);
}
TokKind HloLexer::LexComment() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"};
if (RE2::Consume(&consumable, *comment_pattern)) {
current_ptr_ = consumable.begin();
return TokKind::kComment;
}
return TokKind::kError;
}
// Lexes quoted string with escaping characters. If matched, the quoted string
// will be unescaped and stored to str_val_.
TokKind HloLexer::LexString() {
@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) {
return "kRparen";
case TokKind::kArrow:
return "kArrow";
case TokKind::kComment:
return "kComment";
case TokKind::kw_HloModule:
return "kw_HloModule";
case TokKind::kw_ENTRY:

View File

@ -105,7 +105,6 @@ class HloLexer {
TokKind LexShape();
TokKind LexConstant();
TokKind LexNumberOrPattern();
TokKind LexComment();
TokKind LexString();
const tensorflow::StringPiece buf_;

View File

@ -231,6 +231,7 @@ HLO_MATCHER(Tanh);
HLO_MATCHER(Trace);
HLO_MATCHER(Transpose);
HLO_MATCHER(Tuple);
HLO_MATCHER(TupleSelect);
HLO_MATCHER(While);
// The special cases below let you check additional information about the

View File

@ -166,7 +166,7 @@ class HloModuleGroupMetadata {
//
// Precondition: IsCompanionWhile(instruction) is true.
const std::unordered_set<HloInstruction*>& Companions(
HloInstruction* instruction) const {
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
@ -243,7 +243,7 @@ class HloModuleGroupMetadata {
companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
// Map from computation to the instruction using it (a kWhile, kConditional).
tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -37,24 +38,38 @@ namespace xla {
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
std::vector<HloInstruction*> predecessors;
std::vector<HloInstruction*>
predecessors; // Use a vector to avoid non-determinism.
tensorflow::gtl::FlatSet<HloInstruction*> unique;
// Adds to the unique predecessors list and also add companion instructions
// if the given predecessor has those.
// Adds to the unique predecessors list; if the predecessors is a companion
// instruction, also add companion instructions; if the predecessors is a
// cross-module all-reduce, also add the all-reduce instructions in the same
// group.
auto add_unique_predecessor = [&](HloInstruction* predecessor) {
if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
predecessors.end()) {
if (unique.find(predecessor) != unique.end()) {
return;
}
if (!metadata_.IsCompanionInstruction(predecessor)) {
predecessors.push_back(predecessor);
if (metadata_.IsCompanionInstruction(predecessor)) {
for (HloInstruction* instr : metadata_.Companions(predecessor)) {
if (unique.insert(instr).second) {
predecessors.push_back(instr);
}
}
return;
}
for (HloInstruction* companion : metadata_.Companions(predecessor)) {
predecessors.push_back(companion);
if (predecessor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
if (unique.insert(instr).second) {
predecessors.push_back(instr);
}
}
return;
}
unique.insert(predecessor);
predecessors.push_back(predecessor);
};
// If the given instruction is a companion instruction, we need to find the
// predecessors of all of its companion instructions. If the instruction is an
// all-reduce, we need to find the predecessors of all the peer all-reduce
@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
std::vector<HloInstruction*> successors;
std::vector<HloInstruction*>
successors; // Use a vector to avoid non-determinism.
tensorflow::gtl::FlatSet<HloInstruction*> unique;
// Adds to the unique successors list and also add companion instructions
// if the given successor has those.
// Adds to the unique successors list; if the successor is a companion
// instruction, also add companion instructions; if the successor is a
// cross-module all-reduce, also add the all-reduce instructions in the same
// group.
auto add_unique_successor = [&](HloInstruction* successor) {
if (std::find(successors.begin(), successors.end(), successor) !=
successors.end()) {
if (unique.find(successor) != unique.end()) {
return;
}
if (!metadata_.IsCompanionInstruction(successor)) {
successors.push_back(successor);
if (metadata_.IsCompanionInstruction(successor)) {
for (HloInstruction* instr : metadata_.Companions(successor)) {
if (unique.insert(instr).second) {
successors.push_back(instr);
}
}
return;
}
for (HloInstruction* companion : metadata_.Companions(successor)) {
successors.push_back(companion);
if (successor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
if (unique.insert(instr).second) {
successors.push_back(instr);
}
}
return;
}
unique.insert(successor);
successors.push_back(successor);
};
// If the given instruction is a companion instruction, we need to find the

View File

@ -1824,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
break;
}
case TokKind::kComma:
case TokKind::kComment:
// Skip.
lexer_.Lex();
break;

View File

@ -1560,6 +1560,81 @@ ENTRY consts {
"last");
}
TEST_F(HloParserTest, Comments) {
const string original = R"(/* module description. */
HloModule comments:
ENTRY /*comment*/ c1 {
/* blah */
ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
/* comment */
}
/* something else */
)";
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
}
TEST_F(HloParserTest, MultilineComments) {
const string original = R"(HloModule multiline_comment:
ENTRY c1 {
/*
ROOT foo = f32[1]{0} constant({12345})
*/
ROOT const1 = f32[1]{0} constant({12345})
/*
a
b
c
d
*/
})";
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
}
TEST_F(HloParserTest, UnterminatedComment) {
const string original = R"(HloModule unterminated_comment:
ENTRY c1 {
/* unterminated
ROOT const1 = f32[1]{0} constant({12345})
})";
// Verify that the error message points to the beginning of the unterminated
// comment.
ExpectHasSubstr(ParseHloString(original).status().error_message(),
"/* unterminated\n^");
}
TEST_F(HloParserTest, SlashSlashComments) {
const string original = R"(HloModule slash_slash_comment:
// Garbage
ENTRY c1 {
// Foo bar
ROOT const1 = f32[1]{0} constant({12345}) // Something else
})";
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
}
TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
const string original =
"HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
"bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
}
TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
const string original =
"HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
"bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
}
TEST_F(HloParserTest, MultipleEntries) {
const string original = R"(HloModule multiple_entries:
ENTRY c1 {

View File

@ -44,7 +44,6 @@ enum class TokKind {
kRparen, // ( )
kArrow, // ->
kComment, // /*xxx*/
// Keywords
kw_HloModule,

View File

@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
}
// No "synchronize all activity" implemented for this platform at the moment.
bool SynchronizeAllActivity() override { return false; }
bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
return false;
}

View File

@ -0,0 +1,350 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
using tensorflow::gtl::ArraySlice;
// Transposes the given scatter_indices such that the index_vector_dim becomes
// the most-minor dimension.
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
HloInstruction* scatter_indices, int64 index_vector_dim) {
const Shape& scatter_indices_shape = scatter_indices->shape();
if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
return scatter_indices;
}
if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
return scatter_indices;
}
std::vector<int64> permutation;
permutation.reserve(scatter_indices_shape.dimensions_size());
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
permutation.push_back(i);
}
}
permutation.push_back(index_vector_dim);
return MakeTransposeHlo(scatter_indices, permutation);
}
// Canonicalizes the scatter_indices tensor in order to keep them uniform while
// performing the scatter operation.
static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
HloInstruction* scatter_indices, int64 index_vector_dim) {
// Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
HloInstruction * transposed_scatter_indices,
TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
bool indices_are_scalar =
index_vector_dim == scatter_indices->shape().dimensions_size();
// The number of dimensions in scatter_indices that are index dimensions.
const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
// If there is only one index (i.e. scatter_indices has rank 1 and this
// scatter is really just a dynamic update slice) add a leading degenerate
// dimension for uniformity. Otherwise create a "collapsed" leading dimension
// that subsumes all of the non-index-vector dimensions.
const Shape& shape = transposed_scatter_indices->shape();
if (shape.dimensions_size() == index_dims_in_scatter_indices) {
return PrependDegenerateDims(transposed_scatter_indices, 1);
} else {
// Collapse all but the dimensions (0 or 1) in scatter_indices containing
// the index vectors.
return CollapseFirstNDims(
transposed_scatter_indices,
shape.dimensions_size() - index_dims_in_scatter_indices);
}
}
// Permutes the `updates` tensor such that all the scatter dims appear in the
// major dimensions and all the window dimensions appear in the minor
// dimensions.
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
HloInstruction* updates, ArraySlice<int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
permutation.reserve(updates_rank);
for (int64 i = 0; i < updates_rank; ++i) {
bool is_scatter_dim = !c_binary_search(update_window_dims, i);
if (is_scatter_dim) {
permutation.push_back(i);
}
}
for (auto window_dim : update_window_dims) {
permutation.push_back(window_dim);
}
return MakeTransposeHlo(updates, permutation);
}
// Expands or contracts the scatter indices in the updates tensor.
static StatusOr<HloInstruction*> AdjustScatterDims(
const Shape& scatter_indices_shape, HloInstruction* updates,
int64 index_vector_dim) {
int64 num_scatter_dims = scatter_indices_shape.dimensions_size();
if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
--num_scatter_dims;
}
if (num_scatter_dims == 0) {
// If there are no scatter dims, this must be a dynamic-update-slice kind of
// scatter. In this case, we prepend a degenerate dimension to work
// uniformly in the while loop.
return PrependDegenerateDims(updates, 1);
}
return CollapseFirstNDims(updates, num_scatter_dims);
}
// Expands an index vector from the scatter_indices tensor into a vector that
// can be used to dynamic-update-slice to perform the scatter update.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
int64 operand_rank) {
HloComputation* computation = index_vector->parent();
const Shape& index_shape = index_vector->shape();
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
// We extract out individual components from the smaller index and concatenate
// them (interspersing zeros as needed) into the larger index.
std::vector<HloInstruction*> expanded_index_components;
for (int i = 0; i < operand_rank; i++) {
int64 index_vector_dim_index =
FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
if (index_vector_dim_index !=
dim_numbers.scatter_dims_to_operand_dims_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
/*limit_indices=*/{index_vector_dim_index + 1},
/*strides=*/{1}));
expanded_index_components.push_back(component_to_concat);
} else {
expanded_index_components.push_back(zero);
}
}
return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}
// Body of the while loop that performs the scatter operation using other HLOs.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
HloInstruction* scatter, HloInstruction* induction_var,
const std::vector<HloInstruction*>& loop_state) {
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
CHECK_EQ(loop_state.size(), 3);
HloInstruction* operand = loop_state[0];
HloInstruction* scatter_indices = loop_state[1];
HloInstruction* updates = loop_state[2];
bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;
CHECK_EQ(has_scalar_indices,
dim_numbers.index_vector_dim() ==
scatter->operand(1)->shape().dimensions_size());
// Build a vector form of the induction variable of the while loop.
TF_ASSIGN_OR_RETURN(
HloInstruction * induction_var_as_vector,
MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/{1}));
// Pick the index to scatter from scatter_indices based on the induction_var
// and transform that to an index into the `operand` space.
HloInstruction* index_vector;
if (has_scalar_indices) {
TF_ASSIGN_OR_RETURN(
index_vector,
MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
} else {
TF_ASSIGN_OR_RETURN(
HloInstruction * index_into_scatter_indices,
PadVectorWithZeros(induction_var_as_vector,
/*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
int index_vector_size = scatter_indices->shape().dimensions(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
{1, index_vector_size}));
TF_ASSIGN_OR_RETURN(index_vector,
ElideDegenerateDims(index_vector_2d, {0}));
}
TF_ASSIGN_OR_RETURN(
HloInstruction * scatter_slice_start,
ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
operand->shape().dimensions_size()));
// Extract the slice to be used to update from `updates` tensor for the
// induction_var corresponding to this iteration of the while loop.
TF_ASSIGN_OR_RETURN(
HloInstruction * index_into_updates,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
/*zeros_to_append=*/updates->shape().dimensions_size() - 1));
std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
updates->shape().dimensions().end());
update_slice_bounds[0] = 1;
TF_ASSIGN_OR_RETURN(
HloInstruction * update_slice,
MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
ElideDegenerateDims(update_slice, {0}));
TF_ASSIGN_OR_RETURN(
HloInstruction * update_slice_with_dims_inserted,
InsertDegenerateDims(update_slice_for_scatter,
AsInt64Slice(dim_numbers.inserted_window_dims())));
// Extact the slice to update from `operand` tensor.
const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
TF_ASSIGN_OR_RETURN(
HloInstruction * operand_slice_to_update,
MakeDynamicSliceHlo(operand, scatter_slice_start,
AsInt64Slice(update_slice_shape.dimensions())));
// Compute the new value for the slice to be updated in `operand` tensor by
// combining the existing value and the update value using the update
// computation.
TF_ASSIGN_OR_RETURN(
HloInstruction * updated_operand_slice,
MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
scatter->to_apply()));
// Write the updated value of the slice into `operand` tensor.
TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
scatter_slice_start));
return StatusOr<std::vector<HloInstruction*>>{
{updated_operand, scatter_indices, updates}};
}
// High Level Algorithm.
//
// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
// each row is an index into the operand.
// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1`
// and the scatter dim is the most-major dimension.
// 3. Iterate over the set of indices in the canonicalized scatter_indices
// tensor using a while loop, updating the operand for each such index. Each
// iteration of this while loop performs the following:
// a. Pick the index from scatter_indices for this iteration.
// b. Transfrom this index into an index into the operand space.
// c. Extract the slice to be used to update from the updates tensor.
// d. Extract the slice to update from the operand tensor.
// e. Compute the new value for the slice to update by combining the slices
// from c. and d. using the update_computation of scatter.
// f. Write the updated value of the slice into the operand tensor.
StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
HloInstruction* scatter) {
HloInstruction* operand = scatter->mutable_operand(0);
HloInstruction* scatter_indices = scatter->mutable_operand(1);
HloInstruction* updates = scatter->mutable_operand(2);
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
// If the updates tensor is empty, there is no need to update the operand. We
// can return the operand as is.
if (ShapeUtil::IsZeroElementArray(updates->shape())) {
return operand;
}
// Compute the trip count for the while loop to be used for scatter. This
// should be the number of indices we should scatter into the operand.
const Shape& scatter_indices_shape = scatter_indices->shape();
int64 scatter_loop_trip_count = 1;
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
}
}
if (!IsInt32(scatter_loop_trip_count)) {
return Unimplemented(
"Scatter operations with more than 2147483647 scatter indices are not "
"supported. This error occurred for %s.",
scatter->ToString().c_str());
}
// Canonicalize the scatter_indices, after which the size of its most-major
// dimension must be same as the while loop trip count.
TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
CanonicalizeScatterIndices(
scatter_indices, dim_numbers.index_vector_dim()));
CHECK_EQ(scatter_loop_trip_count,
canonical_scatter_indices->shape().dimensions(0));
// Canonicalize the updates, after which the size of its most-major dimension
// must be same as the while loop trip count.
TF_ASSIGN_OR_RETURN(
HloInstruction * canonical_updates,
PermuteScatterAndWindowDims(
updates, AsInt64Slice(dim_numbers.update_window_dims())));
TF_ASSIGN_OR_RETURN(
HloInstruction * adjusted_canonical_updates,
AdjustScatterDims(scatter_indices->shape(), canonical_updates,
dim_numbers.index_vector_dim()));
CHECK_EQ(scatter_loop_trip_count,
adjusted_canonical_updates->shape().dimensions(0));
// The while loop that implements the scatter operation.
StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
WhileUtil::MakeCountedLoop(
scatter->parent(), scatter_loop_trip_count,
{operand, canonical_scatter_indices, adjusted_canonical_updates},
[&](HloInstruction* induction_var,
const std::vector<HloInstruction*>& loop_state) {
return ScatterLoopBody(scatter, induction_var, loop_state);
});
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
scatter_loop_result_status);
return scatter_loop_result.front();
}
StatusOr<bool> ScatterExpander::Run(HloModule* module) {
std::vector<HloInstruction*> scatter_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instr : computation->instructions()) {
if (instr->opcode() == HloOpcode::kScatter) {
scatter_instrs.push_back(instr);
}
}
}
for (auto instr : scatter_instrs) {
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
TF_RETURN_IF_ERROR(
instr->parent()->ReplaceInstruction(instr, expanded_root));
}
return !scatter_instrs.empty();
}
} // namespace xla

View File

@ -0,0 +1,34 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
class ScatterExpander : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_

View File

@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
if (!IsTuple(shape)) {
return 1;
}
int64 count = 0;
ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) {
if (IsLeafIndex(shape, index)) {
++count;
}
});
for (const Shape& subshape : shape.tuple_shapes()) {
count += GetLeafCount(subshape);
}
return count;
}

View File

@ -709,6 +709,21 @@ xla_test(
],
)
xla_test(
name = "scatter_test",
srcs = ["scatter_test.cc"],
deps = [
":client_library_test_base",
":hlo_test_base",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
# Repeat dot_operation_runtime_test with single-threaded eigen.
xla_test(
name = "dot_operation_single_threaded_runtime_test",
@ -2061,6 +2076,8 @@ tf_cc_test(
xla_test(
name = "test_utils_test",
srcs = ["test_utils_test.cc"],
# There is nothing backend specific in this test, so just pick an arbitrary backend.
backends = ["cpu"],
deps = [
":local_client_test_base",
":test_utils",

View File

@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test {
string TestName() const;
void SetFastMathDisabled(bool disabled) {
execution_options_.mutable_debug_options()->set_xla_enable_fast_math(
!disabled);
auto* opts = execution_options_.mutable_debug_options();
opts->set_xla_cpu_enable_fast_math(!disabled);
opts->set_xla_gpu_enable_fast_math(!disabled);
}
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }

View File

@ -1464,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
}
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
const string hlo_string = R"(
HloModule reduce-window
%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
%param0 = f16[] parameter(0)
ROOT %param1 = f16[] parameter(1)
}
ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
%parameter.0 = f16[81,8]{1,0} parameter(0)
%parameter.1 = f16[] parameter(1)
ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,615 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
using tensorflow::gtl::nullopt;
class ScatterTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
Literal* scatter_indices, Literal* updates) {
RunTest(hlo_text, {operand, scatter_indices, updates});
}
void RunTest(const string& hlo_text,
tensorflow::gtl::ArraySlice<Literal*> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({0, 2});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
const char* hlo_text = R"(
HloModule TensorFlowScatterV2
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[3,2] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={0},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({0, 2});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Add
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({0, 2});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Mul
mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=mul_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({0, 2});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
const string hlo_text = R"(
HloModule TensorFlowScatter_F32
add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}
ENTRY main {
operand = f32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = f32[2,3] parameter(2)
ROOT scatter = f32[3,3] scatter(operand, indices, updates),
to_apply=add_f32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({2, 1});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
const char* hlo_text = R"(
HloModule TensorFlowScatter
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({1, 1});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
const char* hlo_text = R"(
HloModule TensorFlowScatterMultipleBatchDims
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,3,2] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=2
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
const char* hlo_text = R"(
HloModule TensorFlowScatterNd
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
const char* hlo_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
const char* hlo_text = R"(
HloModule DynamicUpdateSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[1,1] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={0,1},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({1, 1});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
const char* hlo_text = R"(
HloModule BatchDynamicUpdateSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,1,1] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
const char* hlo_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,0] parameter(2)
ROOT scatter = s32[3,0] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR1<int32>({0, 2});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
const string hlo_text = R"(
HloModule Scatter_NoUpdateWindowDims
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
std::unique_ptr<Literal> scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
const string hlo_text = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
const string hlo_text = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, NegativeIndex) {
const string hlo_text = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, OneScalarIndex) {
const char* hlo_text = R"(
HloModule OneScalarIndex
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
updates = s32[1,3,2]{2,1,0} parameter(2)
ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates),
to_apply=update_s32,
update_window_dims={0,1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
std::unique_ptr<Literal> updates =
LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, ScalarUpdate) {
const char* hlo_text = R"(
HloModule ScalarUpdate
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
updates = s32[] parameter(2)
ROOT scatter = s32[4]{0} scatter(operand, index, updates),
to_apply=update_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
XLA_TEST_F(ScatterTest, EmptyIndices) {
const string hlo_text = R"(
HloModule EmptyIndices
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3] parameter(0)
indices = s32[0] parameter(1)
updates = s32[0] parameter(2)
ROOT scatter = s32[3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
}
} // namespace
} // namespace xla

View File

@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
const Shape& input_shape, const Shape& slice_shape,
std::minstd_rand0* engine) {
const int64 rank = ShapeUtil::Rank(input_shape);
std::vector<int32> start_indices(rank);
std::unique_ptr<Literal> MakeRandomIndex(
tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < rank; ++i) {
const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
std::uniform_int_distribution<int32> generator(0, upper_bound);
for (int i = 0; i < index_space.size(); ++i) {
std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
HloInstruction* needs_constant = nullptr;
std::vector<int64> index_space;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
if (needs_index != nullptr) {
auto needs_index_shape = needs_index->shape();
auto use_shape = use->shape();
if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
needs_index_shape = needs_index->operand(0)->shape();
case HloOpcode::kDynamicUpdateSlice: {
const Shape& indexed_shape = use->operand(0)->shape();
const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
? use->shape()
: use->operand(1)->shape();
const int64 rank = ShapeUtil::Rank(indexed_shape);
if (!index_space.empty()) {
TF_RET_CHECK(rank == index_space.size());
for (int64 i = 0; i < rank; ++i) {
index_space[i] = std::min(
index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i));
}
if (use->opcode() == HloOpcode::kDynamicSlice) {
use_shape = use->operand(0)->shape();
}
if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
return Unimplemented(
"Conflicting operand generation slice index constraints\n");
} else {
index_space.resize(rank);
for (int64 i = 0; i < rank; ++i) {
index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
}
}
needs_index = use;
break;
}
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;
needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
needs_constant = use;
needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
if (needs_index != nullptr && needs_constant != nullptr) {
if (!index_space.empty() && needs_constant) {
return Unimplemented(
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
"constant: %s\n",
needs_index->ToString().c_str(), needs_constant->ToString().c_str());
"Conflicting operand generation constraints. Dynamically indexes a "
"shape and is the init value of a reduction.");
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
needs_index->shape(), engine);
} else if (needs_constant != nullptr) {
if (!index_space.empty()) {
return MakeRandomIndex(index_space, engine);
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
*dataflow, *params[i], engine.get()));
arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
.ValueOrDie();
}
return std::move(arguments);
}

View File

@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
auto module = ParseHloString(
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
const Literal& index_arg = *args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
auto module = ParseHloString(
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
const Literal& index_arg = *args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
} // namespace
} // namespace xla

View File

@ -104,15 +104,6 @@ message DebugOptions {
// interpretation of this value is left to the backends.
int32 xla_backend_optimization_level = 31;
// When true, "unsafe" mathematical optimizations are enabled. These
// transformations include but are not limited to:
//
// - Reducing the precision of operations (e.g. using an approximate sin
// function, or transforming x/y into x * (1/y)).
// - Assuming that operations never produce or consume NaN or +/- Inf.
// - Assuming that +0 and -0 are indistinguishable.
bool xla_enable_fast_math = 32;
// Embed the compiler IR as a string in the executable.
bool xla_embed_ir_in_executable = 33;
@ -194,6 +185,16 @@ message DebugOptions {
// Maximum kernel unroll factor for the GPU backend.
int32 xla_gpu_max_kernel_unroll_factor = 98;
// When true, "unsafe" mathematical optimizations are enabled. These
// transformations include but are not limited to:
//
// - Reducing the precision of operations (e.g. using an approximate sin
// function, or transforming x/y into x * (1/y)).
// - Assuming that operations never produce or consume NaN or +/- Inf.
// - Assuming that +0 and -0 are indistinguishable.
bool xla_cpu_enable_fast_math = 99;
bool xla_gpu_enable_fast_math = 100;
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;

View File

@ -218,11 +218,11 @@ class ToBigtableOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator),
dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator",
&iterator),
done);
int64 timestamp_int;
@ -245,9 +245,10 @@ class ToBigtableOp : public AsyncOpKernel {
::google::cloud::bigtable::BulkMutation mutation;
// TODO(saeta): Make # of mutations configurable.
for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
OP_REQUIRES_OK_ASYNC(
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
done);
OP_REQUIRES_OK_ASYNC(ctx,
iterator->GetNext(IteratorContext(ctx),
&components, &end_of_sequence),
done);
if (!end_of_sequence) {
OP_REQUIRES_OK_ASYNC(
ctx,

View File

@ -53,7 +53,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
BigtableTableResource* table,
@ -61,7 +61,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
std::vector<string> columns,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
input_(input),
table_(table),
column_families_(std::move(column_families)),
@ -80,8 +80,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableLookupDataset")}));
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")}));
}
const DataTypeVector& output_dtypes() const override {
@ -96,6 +96,14 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
return "BigtableLookupDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
static ::google::cloud::bigtable::Filter MakeFilter(
const std::vector<string>& column_families,

View File

@ -35,11 +35,13 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix)
: GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) {
: DatasetBase(DatasetContext(ctx)),
table_(table),
prefix_(std::move(prefix)) {
table_->Ref();
}
@ -47,8 +49,8 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")}));
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")}));
}
const DataTypeVector& output_dtypes() const override {
@ -68,6 +70,14 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:

View File

@ -39,11 +39,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string start_key, string end_key)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
table_(table),
start_key_(std::move(start_key)),
end_key_(std::move(end_key)) {
@ -54,8 +54,8 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")}));
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")}));
}
const DataTypeVector& output_dtypes() const override {
@ -75,6 +75,14 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:

View File

@ -52,11 +52,11 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
table_(table),
key_range_(MakeMultiModeKeyRange(
std::move(prefix), std::move(start_key), std::move(end_key))) {
@ -68,7 +68,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")}));
{this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")}));
}
const DataTypeVector& output_dtypes() const override {
@ -87,6 +87,14 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
return "BigtableSampleKeyPairsDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
string start_key,

View File

@ -31,10 +31,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
: GraphDatasetBase(ctx), table_(table) {
: DatasetBase(DatasetContext(ctx)), table_(table) {
table_->Ref();
}
@ -43,7 +43,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")}));
{this, strings::StrCat(prefix, "::BigtableSampleKeys")}));
}
const DataTypeVector& output_dtypes() const override {
@ -63,6 +63,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:

View File

@ -84,7 +84,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key,
@ -92,7 +92,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
std::vector<string> columns, float probability,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
table_(table),
prefix_(std::move(prefix)),
start_key_(std::move(start_key)),
@ -111,8 +111,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableScanDataset")}));
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
}
const DataTypeVector& output_dtypes() const override {
@ -129,6 +129,14 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:

View File

@ -34,7 +34,9 @@
namespace tensorflow {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
using boosted_trees::learner::ObliviousSplitInfo;
using boosted_trees::learner::SplitInfo;
using boosted_trees::learner::stochastic::GradientStats;
using boosted_trees::learner::stochastic::NodeStats;
@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
const Tensor* weak_learner_type_t;
OP_REQUIRES_OK(context,
context->input("weak_learner_type", &weak_learner_type_t));
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
Tensor* gains_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output("gains", TensorShape({num_elements}),
&gains_t));
// For a normal tree, we output a split per partition. For an oblivious
// tree, we output one split for all partitions of the layer
int32 size_output = num_elements;
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
num_elements > 0) {
size_output = 1;
}
Tensor* gains_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
"gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
"split_infos", TensorShape({num_elements}),
&output_splits_t));
OP_REQUIRES_OK(context, context->allocate_output("split_infos",
TensorShape({size_output}),
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
if (num_elements == 0) {
return;
}
SplitBuilderState state(context);
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
ComputeNormalDecisionTree(
&state, normalizer_ratio, num_elements, partition_boundaries,
bucket_boundaries, partition_ids, bucket_ids, gradients_t,
hessians_t, &output_partition_ids, &gains, &output_splits);
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
ComputeObliviousDecisionTree(
&state, normalizer_ratio, num_elements, partition_boundaries,
bucket_boundaries, partition_ids, bucket_ids, gradients_t,
hessians_t, &output_partition_ids, &gains, &output_splits);
break;
}
}
}
private:
void ComputeNormalDecisionTree(
SplitBuilderState* state, const float normalizer_ratio,
const int num_elements, const std::vector<int32>& partition_boundaries,
const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
const Tensor* gradients_t, const Tensor* hessians_t,
tensorflow::TTypes<int32>::Vec* output_partition_ids,
tensorflow::TTypes<float>::Vec* gains,
tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@ -237,21 +283,125 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
dense_split->set_feature_column(state.feature_column_group_id());
dense_split->set_feature_column(state->feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
state.FillLeaf(best_left_node_stats, left_child);
state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(start_index);
state->FillLeaf(best_left_node_stats, left_child);
state->FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&(*output_splits)(root_idx));
(*gains)(root_idx) =
best_gain - root_stats.gain - state->tree_complexity_regularization();
(*output_partition_ids)(root_idx) = partition_ids(start_index);
}
}
void ComputeObliviousDecisionTree(
SplitBuilderState* state, const float normalizer_ratio,
const int num_elements, const std::vector<int32>& partition_boundaries,
const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
const Tensor* gradients_t, const Tensor* hessians_t,
tensorflow::TTypes<int32>::Vec* output_partition_ids,
tensorflow::TTypes<float>::Vec* gains,
tensorflow::TTypes<string>::Vec* output_splits) {
// Holds the root stats per each node to be split.
std::vector<GradientStats> current_layer_stats;
current_layer_stats.reserve(num_elements);
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
const int start_index = partition_boundaries[root_idx];
const int end_index = partition_boundaries[root_idx + 1];
GradientStats root_gradient_stats;
for (int64 bucket_idx = start_index; bucket_idx < end_index;
++bucket_idx) {
root_gradient_stats +=
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
current_layer_stats.push_back(root_gradient_stats);
}
float best_gain = std::numeric_limits<float>::lowest();
int64 best_bucket_idx = 0;
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
int64 current_bucket_id = 0;
int64 last_bucket_id = -1;
// Indexes offsets for each of the partitions that can be used to access
// gradients of a partition for a current bucket we consider.
std::vector<int> current_layer_offsets(num_elements, 0);
std::vector<GradientStats> left_gradient_stats(num_elements);
// The idea is to try every bucket id in increasing order. In each iteration
// we calculate the gain of the layer using the current bucket id as split
// value, and we also obtain the following bucket id to try.
while (current_bucket_id > last_bucket_id) {
last_bucket_id = current_bucket_id;
int64 next_bucket_id = -1;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
int idx =
current_layer_offsets[root_idx] + partition_boundaries[root_idx];
const int end_index = partition_boundaries[root_idx + 1];
if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
GradientStats g(*gradients_t, *hessians_t, idx);
g *= normalizer_ratio;
left_gradient_stats[root_idx] += g;
current_layer_offsets[root_idx]++;
idx++;
}
if (idx < end_index &&
(bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
next_bucket_id = bucket_ids(idx, 0);
}
}
float gain_of_split = 0.0;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
GradientStats right_gradient_stats =
current_layer_stats[root_idx] - left_gradient_stats[root_idx];
NodeStats left_stat =
state->ComputeNodeStats(left_gradient_stats[root_idx]);
NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
gain_of_split += left_stat.gain + right_stat.gain;
current_left_node_stats[root_idx] = left_stat;
current_right_node_stats[root_idx] = right_stat;
}
if (gain_of_split > best_gain) {
best_gain = gain_of_split;
best_left_node_stats = current_left_node_stats;
best_right_node_stats = current_right_node_stats;
}
current_bucket_id = next_bucket_id;
}
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
}
best_gain -= num_elements * state->tree_complexity_regularization();
ObliviousSplitInfo oblivious_split_info;
auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
->mutable_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
oblivious_dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
auto* left_children = oblivious_split_info.add_children_leaves();
auto* right_children = oblivious_split_info.add_children_leaves();
state->FillLeaf(best_left_node_stats[root_idx], left_children);
state->FillLeaf(best_right_node_stats[root_idx], right_children);
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
}
oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
BuildDenseInequalitySplitsOp);

View File

@ -64,6 +64,7 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
self._weak_learner_type = weak_learner_type
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
dense_make_stats_update.add_to_graph(g)
@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
self._min_node_weight, self._loss_uses_sum_reduction))
self._min_node_weight, self._loss_uses_sum_reduction,
self._weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
def _make_dense_split(
quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
stamp_token, next_stamp_token, multiclass_strategy,
class_id, feature_column_id, l1_regularization,
l2_regularization, tree_complexity_regularization,
min_node_weight, is_multi_dimentional,
loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@ -327,7 +332,8 @@ def _make_dense_split(
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
multiclass_strategy=multiclass_strategy))
multiclass_strategy=multiclass_strategy,
weak_learner_type=weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
@ -507,7 +513,40 @@ def _make_sparse_split(
return are_splits_ready, partition_ids, gains, split_infos
def _specialize_make_split(func, is_multi_dimentional):
def _specialize_make_split_dense(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
dtypes.resource,
dtypes.resource,
dtypes.int64,
dtypes.int64,
dtypes.int32,
dtypes.int32,
dtypes.int32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
dtypes.bool,
dtypes.int32,
noinline=True)
def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight, loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a sparse feature column."""
return func(quantile_accumulator_handle, stats_accumulator_handle,
stamp_token, next_stamp_token, multiclass_strategy, class_id,
feature_column_id, l1_regularization, l2_regularization,
tree_complexity_regularization, min_node_weight,
is_multi_dimentional, loss_uses_sum_reduction,
weak_learner_type)
return f
def _specialize_make_split_sparse(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
return f
make_dense_split_scalar = _specialize_make_split(_make_dense_split,
is_multi_dimentional=False)
make_dense_split_tensor = _specialize_make_split(_make_dense_split,
is_multi_dimentional=True)
make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
is_multi_dimentional=False)
make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
is_multi_dimentional=True)
make_dense_split_scalar = _specialize_make_split_dense(
_make_dense_split, is_multi_dimentional=False)
make_dense_split_tensor = _specialize_make_split_dense(
_make_dense_split, is_multi_dimentional=True)
make_sparse_split_scalar = _specialize_make_split_sparse(
_make_sparse_split, is_multi_dimentional=False)
make_sparse_split_tensor = _specialize_make_split_sparse(
_make_sparse_split, is_multi_dimentional=True)
@function.Defun(

View File

@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testObliviousFeatureSplitGeneration(self):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 2 |
# i1 | (-0.5, 0.07) | 0 | 2 |
# i2 | (1.2, 0.2) | 0 | 0 |
# i3 | (4.0, 0.13) | 1 | 1 |
dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
class_id = -1
gradient_shape = tensor_shape.scalar()
hessian_shape = tensor_shape.scalar()
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
l2_regularization=1.,
tree_complexity_regularization=0.,
min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
dense_float_column=dense_column,
init_stamp_token=0,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
resources.initialize_resources(resources.shared_resources()).run()
empty_gradients, empty_hessians = get_empty_tensors(
gradient_shape, hessian_shape)
example_weights = array_ops.ones([4, 1], dtypes.float32)
update_1 = split_handler.update_stats_sync(
0,
partition_ids,
gradients,
hessians,
empty_gradients,
empty_hessians,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
are_splits_ready = split_handler.make_splits(
np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
partition_ids,
gradients,
hessians,
empty_gradients,
empty_hessians,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
]))
# During the first iteration, inequality split handlers are not going to
# have any splits. Make sure that we return not_ready in that case.
self.assertFalse(are_splits_ready)
self.assertTrue(are_splits_ready2)
self.assertAllEqual([0, 1], partitions)
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
split_node = oblivious_split_info.split_node.dense_float_binary_split
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)
# Check the split on partition 0.
# -(1.2 - 0.1) / (0.2 + 1)
expected_left_weight_0 = -0.9166666666666666
# expected_left_weight_0 * -(1.2 - 0.1)
expected_left_gain_0 = 1.008333333333333
# (-0.5 + 0.2 + 0.1) / (0.19 + 1)
expected_right_weight_0 = 0.1680672
# expected_right_weight_0 * -(-0.5 + 0.2 + 0.1))
expected_right_gain_0 = 0.033613445378151252
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
expected_bias_gain_0 = 0.46043165467625896
left_child = oblivious_split_info.children_leaves[0].vector
right_child = oblivious_split_info.children_leaves[1].vector
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
# Check the split on partition 1.
expected_left_weight_1 = 0
expected_left_gain_1 = 0
# -(4 - 0.1) / (0.13 + 1)
expected_right_weight_1 = -3.4513274336283186
# expected_right_weight_1 * -(4 - 0.1)
expected_right_gain_1 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_1 = 13.460176991150442
left_child = oblivious_split_info.children_leaves[2].vector
right_child = oblivious_split_info.children_leaves[3].vector
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
# The layer gain is the sum of the gains of each partition
layer_gain = (
expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
self.assertAllClose(layer_gain, gains[0], 0.00001)
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:

View File

@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
.Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
weak_learner_type: A scalar, specifying the weak learner type to use.
See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.

View File

@ -108,6 +108,11 @@ message LearnerConfig {
DIAGONAL_HESSIAN = 3;
}
enum WeakLearnerType {
NORMAL_DECISION_TREE = 0;
OBLIVIOUS_DECISION_TREE = 1;
}
// Number of classes.
uint32 num_classes = 1;
@ -141,4 +146,7 @@ message LearnerConfig {
// If you want to average the ensembles (for regularization), provide the
// config below.
AveragingConfig averaging_config = 11;
// By default we use NORMAL_DECISION_TREE as weak learner.
WeakLearnerType weak_learner_type = 12;
}

View File

@ -17,3 +17,10 @@ message SplitInfo {
// Right Leaf node.
tensorflow.boosted_trees.trees.Leaf right_child = 3;
}
message ObliviousSplitInfo {
// The split node with the feature_column and threshold defined.
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
// The new leaves of the tree.
repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
}

View File

@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
# .assertEmpty doesn't exist on ubuntu-contrib
self.assertEqual(0, len(partitions))

View File

@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.constraints.min_node_weight, dtypes.float32)
loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
epsilon = 0.01
num_quantiles = 100
strategy_tensor = constant_op.constant(strategy)
@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object):
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
loss_uses_sum_reduction=loss_uses_sum_reduction,
weak_learner_type=weak_learner_type,
))
fc_name_idx += 1

View File

@ -31,6 +31,9 @@ Checkpointable data structures:
@@List
@@Mapping
@@UniqueNameTracker
Checkpoint management:
@@CheckpointManager
"""
from __future__ import absolute_import
@ -41,6 +44,7 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpoint_management import CheckpointManager
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping

View File

@ -204,7 +204,7 @@ def find_best_candidate_distribution(objective_vector,
assert best_pp is not None
# Throughout this loop, a maximum_violation of "lower" is not achievable,
# but a maximum_violation of "upper" is achiveable.
# but a maximum_violation of "upper" is achievable.
while True:
middle = 0.5 * (lower + upper)
if (middle - lower <= epsilon) or (upper - middle <= epsilon):

View File

@ -72,7 +72,8 @@ class ConstrainedMinimizationProblem(object):
else:
proxy_constraints_shape = self.proxy_constraints.get_shape()
if (constraints_shape is None or proxy_constraints_shape is None or
if (constraints_shape.ndims is None or
proxy_constraints_shape.ndims is None or
any([ii is None for ii in constraints_shape.as_list()]) or
any([ii is None for ii in proxy_constraints_shape.as_list()])):
raise ValueError(
@ -121,3 +122,19 @@ class ConstrainedMinimizationProblem(object):
A tensor of proxy constraint functions.
"""
return None
# This is a property, instead of an abstract property, since it doesn't need
# to be overridden: if pre_train_ops returns None, then there are no ops to
# run before train_op.
@property
def pre_train_ops(self):
"""Returns a list of `Operation`s to run before the train_op.
When a `ConstrainedOptimizer` creates a train_op (in `minimize`
`minimize_unconstrained`, or `minimize_constrained`), it will include these
ops before the main training step.
Returns:
A list of `Operation`s.
"""
return None

View File

@ -55,20 +55,21 @@ class ConstrainedOptimizer(object):
"""Returns the `tf.train.Optimizer` used for optimization."""
return self._optimizer
def minimize_unconstrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Op` for minimizing the unconstrained problem.
@abc.abstractmethod
def _minimize_constrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Version of `minimize_constrained` to be overridden by subclasses.
Unlike `minimize_constrained`, this function ignores the `constraints` (and
`proxy_constraints`) portion of the minimization problem entirely, and only
minimizes `objective`.
Implementations of this method should ignore the `pre_train_ops` property of
the `minimization_problem`. The public `minimize_constrained` method will
take care of executing these before the returned train_op.
Args:
minimization_problem: ConstrainedMinimizationProblem, the problem to
@ -83,19 +84,10 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
TensorFlow Op.
`Operation`, the train_op.
"""
return self.optimizer.minimize(
minimization_problem.objective,
global_step=global_step,
var_list=var_list,
gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
name=name,
grad_loss=grad_loss)
pass
@abc.abstractmethod
def minimize_constrained(self,
minimization_problem,
global_step=None,
@ -105,7 +97,7 @@ class ConstrainedOptimizer(object):
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Op` for minimizing the constrained problem.
"""Returns an `Operation` for minimizing the constrained problem.
Unlike `minimize_unconstrained`, this function attempts to find a solution
that minimizes the `objective` portion of the minimization problem while
@ -124,9 +116,83 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
TensorFlow Op.
`Operation`, the train_op.
"""
pass
def train_op_callback():
return self._minimize_constrained(
minimization_problem,
global_step=global_step,
var_list=var_list,
gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
name=name,
grad_loss=grad_loss)
# If we have pre_train_ops, use tf.control_dependencies() to ensure that
# they execute before the train_op.
pre_train_ops = minimization_problem.pre_train_ops
if pre_train_ops:
with ops.control_dependencies(pre_train_ops):
train_op = train_op_callback()
else:
train_op = train_op_callback()
return train_op
def minimize_unconstrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Operation` for minimizing the unconstrained problem.
Unlike `minimize_constrained`, this function ignores the `constraints` (and
`proxy_constraints`) portion of the minimization problem entirely, and only
minimizes `objective`.
Args:
minimization_problem: ConstrainedMinimizationProblem, the problem to
optimize.
global_step: as in `tf.train.Optimizer`'s `minimize` method.
var_list: as in `tf.train.Optimizer`'s `minimize` method.
gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
method.
name: as in `tf.train.Optimizer`'s `minimize` method.
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
`Operation`, the train_op.
"""
def train_op_callback():
return self.optimizer.minimize(
minimization_problem.objective,
global_step=global_step,
var_list=var_list,
gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
name=name,
grad_loss=grad_loss)
# If we have pre_train_ops, use tf.control_dependencies() to ensure that
# they execute before the train_op.
pre_train_ops = minimization_problem.pre_train_ops
if pre_train_ops:
with ops.control_dependencies(pre_train_ops):
train_op = train_op_callback()
else:
train_op = train_op_callback()
return train_op
def minimize(self,
minimization_problem,
@ -138,7 +204,7 @@ class ConstrainedOptimizer(object):
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Op` for minimizing the constrained problem.
"""Returns an `Operation` for minimizing the constrained problem.
This method combines the functionality of `minimize_unconstrained` and
`minimize_constrained`. If global_step < unconstrained_steps, it will
@ -164,14 +230,14 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
TensorFlow Op.
`Operation`, the train_op.
Raises:
ValueError: If unconstrained_steps is provided, but global_step is not.
"""
def unconstrained_fn():
"""Returns an `Op` for minimizing the unconstrained problem."""
"""Returns an `Operation` for minimizing the unconstrained problem."""
return self.minimize_unconstrained(
minimization_problem=minimization_problem,
global_step=global_step,
@ -183,7 +249,7 @@ class ConstrainedOptimizer(object):
grad_loss=grad_loss)
def constrained_fn():
"""Returns an `Op` for minimizing the constrained problem."""
"""Returns an `Operation` for minimizing the constrained problem."""
return self.minimize_constrained(
minimization_problem=minimization_problem,
global_step=global_step,

View File

@ -70,11 +70,13 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
region w.r.t. the Euclidean norm.
Raises:
ValueError: if the `multipliers` tensor does not have a fully-known shape,
or is not one-dimensional.
ValueError: if the `multipliers` tensor is not floating-point, does not have
a fully-known shape, or is not one-dimensional.
"""
if not multipliers.dtype.is_floating:
raise ValueError("multipliers must have a floating-point dtype")
multipliers_shape = multipliers.get_shape()
if multipliers_shape is None:
if multipliers_shape.ndims is None:
raise ValueError("multipliers must have known shape")
if multipliers_shape.ndims != 1:
raise ValueError(
@ -101,12 +103,12 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
(radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive)))
multipliers += scale * inactive
new_inactive = standard_ops.to_float(multipliers > 0)
new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype)
multipliers *= new_inactive
return (iteration, multipliers, new_inactive, inactive)
iteration = standard_ops.constant(0)
inactive = standard_ops.ones_like(multipliers)
inactive = standard_ops.ones_like(multipliers, dtype=multipliers.dtype)
# We actually want a do-while loop, so we explicitly call while_loop_body()
# once before tf.while_loop().
@ -189,16 +191,16 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
def _projection_op(self, state, name=None):
pass
def minimize_constrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Op` for minimizing the constrained problem.
def _minimize_constrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Operation` for minimizing the constrained problem.
The `optimizer` constructor parameter will be used to update the model
parameters, while the Lagrange multipliers will be updated using
@ -216,8 +218,11 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name: as in `tf.train.Optimizer`'s `minimize` method.
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Raises:
ValueError: If the minimization_problem tensors have different dtypes.
Returns:
TensorFlow Op.
`Operation`, the train_op.
"""
objective = minimization_problem.objective
@ -225,6 +230,14 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
proxy_constraints = minimization_problem.proxy_constraints
if proxy_constraints is None:
proxy_constraints = constraints
# Make sure that the objective, constraints and proxy constraints all have
# the same dtype.
if (objective.dtype.base_dtype != constraints.dtype.base_dtype or
objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype):
raise ValueError("objective, constraints and proxy_constraints must "
"have the same dtype")
# Flatten both constraints tensors to 1d.
num_constraints = minimization_problem.num_constraints
constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
@ -241,8 +254,10 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
multipliers = self._lagrange_multipliers(state)
loss = (
objective + standard_ops.tensordot(multipliers, proxy_constraints, 1))
multipliers_gradient = constraints
objective + standard_ops.tensordot(
standard_ops.cast(multipliers, proxy_constraints.dtype),
proxy_constraints, 1))
multipliers_gradient = standard_ops.cast(constraints, multipliers.dtype)
update_ops = []
if self.constraint_optimizer is None:
@ -356,6 +371,8 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer):
# For an AdditiveExternalRegretOptimizer, the internal state is simply a
# tensor of Lagrange multipliers with shape (m,), where m is the number of
# constraints.
#
# FUTURE WORK: make the dtype a parameter.
return standard_ops.zeros((num_constraints,), dtype=dtypes.float32)
def _lagrange_multipliers(self, state):

View File

@ -79,9 +79,11 @@ def _maximal_eigenvector_power_method(matrix,
The maximal right-eigenvector of `matrix`.
Raises:
ValueError: If the epsilon or maximum_iterations parameters violate their
bounds.
ValueError: If the `matrix` tensor is not floating-point, or if the
`epsilon` or `maximum_iterations` parameters violate their bounds.
"""
if not matrix.dtype.is_floating:
raise ValueError("multipliers must have a floating-point dtype")
if epsilon <= 0.0:
raise ValueError("epsilon must be strictly positive")
if maximum_iterations <= 0:
@ -139,11 +141,13 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
(i.e. the Frobenius norm).
Raises:
ValueError: if the `matrix` tensor does not have a fully-known shape, or is
not two-dimensional and square.
ValueError: if the `matrix` tensor is not floating-point, does not have a
fully-known shape, or is not two-dimensional and square.
"""
if not matrix.dtype.is_floating:
raise ValueError("multipliers must have a floating-point dtype")
matrix_shape = matrix.get_shape()
if matrix_shape is None:
if matrix_shape.ndims is None:
raise ValueError("matrix must have known shape")
if matrix_shape.ndims != 2:
raise ValueError(
@ -172,12 +176,12 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
matrix, axis=0, keepdims=True)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
matrix += scale * inactive
new_inactive = standard_ops.to_float(matrix > 0)
new_inactive = standard_ops.cast(matrix > 0, matrix.dtype)
matrix *= new_inactive
return (iteration, matrix, new_inactive, inactive)
iteration = standard_ops.constant(0)
inactive = standard_ops.ones_like(matrix)
inactive = standard_ops.ones_like(matrix, dtype=matrix.dtype)
# We actually want a do-while loop, so we explicitly call while_loop_body()
# once before tf.while_loop().
@ -218,7 +222,7 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
"""Base class representing a `_SwapRegretOptimizer`.
This class contains most of the logic for performing constrained optimization,
minimizing external regret for the constraints player. What it *doesn't* do is
minimizing swap regret for the constraints player. What it *doesn't* do is
keep track of the internal state (the stochastic matrix). Instead, the state
is accessed via the _initial_state(), _stochastic_matrix(),
_constraint_grad_and_var() and _projection_op() methods.
@ -291,16 +295,16 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
def _projection_op(self, state, name=None):
pass
def minimize_constrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Op` for minimizing the constrained problem.
def _minimize_constrained(self,
minimization_problem,
global_step=None,
var_list=None,
gate_gradients=train_optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
"""Returns an `Operation` for minimizing the constrained problem.
The `optimizer` constructor parameter will be used to update the model
parameters, while the constraint/objective weight matrix (the analogue of
@ -320,8 +324,11 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name: as in `tf.train.Optimizer`'s `minimize` method.
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Raises:
ValueError: If the minimization_problem tensors have different dtypes.
Returns:
TensorFlow Op.
`Operation`, the train_op.
"""
objective = minimization_problem.objective
@ -329,6 +336,14 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
proxy_constraints = minimization_problem.proxy_constraints
if proxy_constraints is None:
proxy_constraints = constraints
# Make sure that the objective, constraints and proxy constraints all have
# the same dtype.
if (objective.dtype.base_dtype != constraints.dtype.base_dtype or
objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype):
raise ValueError("objective, constraints and proxy_constraints must "
"have the same dtype")
# Flatten both constraints tensors to 1d.
num_constraints = minimization_problem.num_constraints
constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
@ -344,15 +359,18 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name="swap_regret_optimizer_state")
zero_and_constraints = standard_ops.concat(
(standard_ops.zeros((1,)), constraints), axis=0)
(standard_ops.zeros((1,), dtype=constraints.dtype), constraints),
axis=0)
objective_and_proxy_constraints = standard_ops.concat(
(standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0)
distribution = self._distribution(state)
loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints,
1)
loss = standard_ops.tensordot(
standard_ops.cast(distribution, objective_and_proxy_constraints.dtype),
objective_and_proxy_constraints, 1)
matrix_gradient = standard_ops.matmul(
standard_ops.expand_dims(zero_and_constraints, 1),
standard_ops.expand_dims(
standard_ops.cast(zero_and_constraints, distribution.dtype), 1),
standard_ops.expand_dims(distribution, 0))
update_ops = []
@ -555,6 +573,7 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer):
log_initial_one = math.log(1.0 - (self._initial_multiplier_radius *
(dimension - 1) / (dimension)))
log_initial_zero = math.log(self._initial_multiplier_radius / dimension)
# FUTURE WORK: make the dtype a parameter.
return standard_ops.concat(
(standard_ops.constant(
log_initial_one, dtype=dtypes.float32, shape=(1, dimension)),

View File

@ -42,13 +42,13 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& transformations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
input_(input),
transformations_(transformations),
output_types_(output_types),

View File

@ -131,7 +131,7 @@ class CSVDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
string compression_type, io::ZlibCompressionOptions options,
@ -139,7 +139,7 @@ class CSVDatasetOp : public DatasetOpKernel {
const std::vector<PartialTensorShape>& output_shapes,
std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
bool use_quote_delim, char delim, string na_value)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
header_(header),
out_type_(output_types),

View File

@ -63,11 +63,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
: GraphDatasetBase(ctx),
: DatasetBase(DatasetContext(ctx)),
selector_input_(selector_input),
data_inputs_(std::move(data_inputs)) {
selector_input_->Ref();

View File

@ -35,10 +35,10 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
: GraphDatasetBase(ctx), input_(input) {
: DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
}

View File

@ -929,10 +929,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
core::ScopedUnref unref(resource);
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK(ctx,
dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator",
&iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));

View File

@ -130,11 +130,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
ThreadPoolResource* threadpool)
: GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) {
: DatasetBase(DatasetContext(ctx)),
input_(input),
threadpool_(threadpool) {
input_->Ref();
threadpool_->Ref();
}
@ -165,9 +167,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(
"Cannot currently serialize the thread pool for a "
"ThreadPoolDataset.");
return errors::Unimplemented("%s does not support serialization",
DebugString());
}
private:

View File

@ -47,10 +47,10 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public GraphDatasetBase {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input)
: GraphDatasetBase(ctx), input_(input) {
: DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
}

Some files were not shown because too many files have changed in this diff Show More