Merge branch 'master' into java-eager-tensor

Karl Lessard 2019-05-12 00:34:00 -04:00 committed by GitHub
commit 4384648a78
1265 changed files with 57386 additions and 46070 deletions

View File

@ -293,7 +293,6 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
tags = [
"no_oss", # http://b/119522529
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/kernels/logging_ops.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"

View File

@ -36,6 +36,7 @@ py_binary(
name = "make_test_graphs",
testonly = 1,
srcs = ["make_test_graphs.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",

View File

@ -220,6 +220,7 @@ cc_library(
name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:graph"],
)
@ -262,7 +263,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -270,6 +270,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
@ -466,6 +467,9 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],

View File

@ -165,6 +165,18 @@ bool LogNotCompilableAndReturn(const Node& node,
return false;
}
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd";
}
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd" || node.type_string() == "Qr";
}
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) {
// _Arg nodes in a top-level function represent feeds and _Retval nodes in a
@ -228,8 +240,12 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
"resource variable op in called function");
}
if (!op_filter_.allow_svd_op && node.type_string() == "Svd") {
return LogNotCompilableAndReturn(node, "Svd ops disabled");
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation");
}
return true;
@ -248,7 +264,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops;
op_filter.allow_svd_op = registration.cluster_svd_op;
op_filter.allow_slow_and_inaccurate_ops =
registration.cluster_slow_and_inaccurate_ops;
return op_filter;
}

View File

@ -97,10 +97,9 @@ class RecursiveCompilabilityChecker {
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;
// Whether the "Svd" op should be auto-clustered. The XLA implemenation of
// this op has some performance (b/128001705) and possibly correctness
// (b/127344411) issues so we avoid auto-clustering it.
bool allow_svd_op;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool allow_slow_and_inaccurate_ops;
};
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
@ -119,6 +118,11 @@ class RecursiveCompilabilityChecker {
return IsCompilableCall(call_def, /*depth=*/0, lib_runtime);
}
// Returns true if XLA supports this Op, but we don't want to cluster it (i.e.,
// due to performance or correctness concerns).
bool OpIsInaccurate(const Node& node);
bool OpIsSlow(const Node& node);
private:
bool IsCompilableNode(const Node& node, int depth,
FunctionLibraryRuntime* lib_runtime);

View File

@ -371,7 +371,8 @@ class PredicateFactory {
Predicate** predicate) {
TensorId tensor_id(node->name(), output_idx);
bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
bool is_boolean_tensor =
BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
TF_RET_CHECK(!must_be_true || is_boolean_tensor);
if (node->type_string() == "Const" && must_be_true) {
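
The hunk above widens the boolean check with BaseType, so predicates fed by reference-typed bool variables (as exercised by the RefBoolSwitchCondition test below) are still recognized as boolean tensors. A minimal standalone sketch of that property, assuming TensorFlow's types.h and not part of the diff itself:

#include <cassert>
#include "tensorflow/core/framework/types.h"

int main() {
  // BaseType() strips the reference bit from a DataType, so a Switch
  // predicate read from a bool ref variable still counts as a boolean tensor.
  assert(tensorflow::BaseType(tensorflow::DT_BOOL_REF) == tensorflow::DT_BOOL);
  assert(tensorflow::BaseType(tensorflow::DT_BOOL) == tensorflow::DT_BOOL);
  return 0;
}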

View File

@ -1067,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
}
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
Scope root = Scope::NewRootScope().ExitOnError();
Output condition_ref_var =
ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var);
Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);
FixupSourceAndSinkEdges(root.graph());
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0");
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0");
}
} // namespace
} // namespace tensorflow

View File

@ -97,28 +97,19 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
return Status::OK();
}
Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
absl::optional<jit::DeviceId>* out_device_picked) {
if (out_can_pick_device) {
*out_can_pick_device = true;
}
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
bool failure_to_pick_is_error) {
#define FAILED_TO_PICK_DEVICE(failing_status) \
do { \
if (out_can_pick_device) { \
*out_can_pick_device = false; \
return Status::OK(); \
} else { \
if (failure_to_pick_is_error) { \
return failing_status; \
} else { \
return {absl::nullopt}; \
} \
} while (false)
TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
absl::optional<jit::DeviceId> maybe_gpu_device;
absl::optional<jit::DeviceId> maybe_cpu_device;
absl::optional<jit::DeviceId> maybe_unknown_device;
@ -182,17 +173,15 @@ Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
}
}
if (out_device_picked) {
if (maybe_gpu_device) {
*out_device_picked = *maybe_gpu_device;
} else if (maybe_unknown_device) {
*out_device_picked = *maybe_unknown_device;
} else {
*out_device_picked = *maybe_cpu_device;
}
if (maybe_gpu_device) {
return {*maybe_gpu_device};
} else if (maybe_unknown_device) {
return {*maybe_unknown_device};
} else if (maybe_cpu_device) {
return {*maybe_cpu_device};
}
return Status::OK();
FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
#undef FAILED_TO_PICK_DEVICE
}
@ -200,21 +189,18 @@ Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
absl::optional<jit::DeviceId> device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
device_info_cache, devices, allow_mixing_unknown_and_cpu,
/*out_can_pick_device=*/nullptr, &device));
return *device;
TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/true));
return *device_id;
}
xla::StatusOr<bool> CanPickDeviceForXla(
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
bool can_pick_device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
&can_pick_device,
/*out_device_picked=*/nullptr));
return can_pick_device;
return PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/false);
}
} // namespace tensorflow

View File

@ -71,17 +71,34 @@ class DeviceSet {
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
if (word & (1ull << bit_index)) {
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
while (word != 0) {
uint64 only_lowest_bit_set = word & -word;
// The number of trailing zeros in a non-zero word is the index of the
// least significant 1.
int bit_index = ctz_uint64(word);
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
word ^= only_lowest_bit_set;
}
}
}
private:
static int ctz_uint64(uint64 x) {
DCHECK_NE(x, 0);
#ifdef __GNUC__
return __builtin_ctzl(x);
#else
int result = 0u;
while ((x & 1u) == 0u) {
x >>= 1;
++result;
}
return result;
#endif
}
absl::InlinedVector<uint64, 1> storage_;
const int kWordSize = 64;
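
For context, a standalone sketch (not part of the diff) of the set-bit walk used in the DeviceSet hunk above: clear the lowest set bit on each iteration and use count-trailing-zeros to get its index, so the cost is proportional to the number of devices present rather than to kWordSize. The GCC/Clang builtin __builtin_ctzll is assumed here.

#include <cstdint>
#include <iostream>

int main() {
  std::uint64_t word = (1ull << 0) | (1ull << 3) | (1ull << 5);  // bits 0, 3, 5 set
  while (word != 0) {
    std::uint64_t only_lowest_bit_set = word & (~word + 1);  // same as word & -word
    int bit_index = __builtin_ctzll(word);  // index of the least significant 1
    std::cout << "set bit at index " << bit_index << "\n";   // prints 0, then 3, then 5
    word ^= only_lowest_bit_set;  // clear that bit and continue with the rest
  }
  return 0;
}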
@ -181,9 +198,12 @@ xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
// This is like `PickDeviceForXla` except that it returns false (instead of a
// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
// non-OK Status) if no unambiguous choice of device exists.
xla::StatusOr<bool> CanPickDeviceForXla(
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow

View File

@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
std::map<string, int>{}});
}
bool modified;
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
graph_out.get(), flr, lib_def.get());
graph_out.get(), flr, lib_def.get(), &modified);
if (!s.ok()) return s;
GraphDef graphdef_out;
@ -1105,7 +1106,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
@ -1985,7 +1988,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
@ -2110,7 +2115,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@ -2258,7 +2265,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
@ -2271,7 +2279,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});

View File

@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include <algorithm>
#include <iterator>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/shape_inference.h"
@ -24,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow {
@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
return Status::OK();
}
StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name) {
auto cluster_deps = absl::make_unique<
absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
for (const Edge* e : g->edges()) {
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
if (dst_deps_it == cluster_deps->end()) {
cluster_deps->insert(std::make_pair(
*dst_outside_compilation,
absl::flat_hash_set<string>({*src_outside_compilation})));
} else {
dst_deps_it->second.insert(*src_outside_compilation);
}
}
}
auto cluster_deps_ordered =
absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
std::vector<string> ordered_deps(it->second.begin(), it->second.end());
std::sort(ordered_deps.begin(), ordered_deps.end());
cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
}
return std::move(cluster_deps_ordered);
}
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges

View File

@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -89,6 +91,15 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core;
};
// Finds dependencies between outside compilation clusters, including both data
// dependencies and control dependencies. cluster_deps maps the name of an
// outside compilation cluster to a set of names of outside compilation clusters
// that it depends on.
stream_executor::port::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);
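// Illustrative example (not part of this header): in the control-dependency
// test added later in this diff, cluster "1" depends on cluster "0", so the
// returned map is { "1" -> ["0"] }; each per-cluster vector is sorted to keep
// the result deterministic.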
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//

View File

@ -15,12 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@ -287,15 +289,20 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
return results;
}
string host_compute_node_name(const string& original_oc_name) {
return absl::StrCat("outside_compilation_", original_oc_name,
"_host_compute");
}
// Builds XlaHostCompute NodeDef from the outside compilation call node.
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
const Node* call_node, const std::map<string, int>& host_compute_core) {
const Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
string original_oc_name;
TF_RETURN_IF_ERROR(GetNodeAttr(
call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
NodeDefBuilder host_compute_builder(
absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
"XlaHostCompute");
NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
"XlaHostCompute");
// Copy all attributes.
for (auto attr : call_node->attrs()) {
@ -309,9 +316,25 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
host_compute_builder.Attr("tpu_core", core);
}
// Set input tokens.
host_compute_builder.Attr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName});
// Set input tokens, as well as the other outside compilation clusters that
// the current cluster depends on, in `kXlaTokenInputNodesAttrName`. This is
// needed because when
// outside compilation subgraphs are encapsulated and moved to host graph,
// control/data edges between them will only be reflected in host graph.
// From XLA's perspective, two originally dependent clusters are no longer
// connected, which makes them look like they can be scheduled for execution
// in arbitrary order even though in fact they must be executed in order
// according to their host-side graph dependency. This can cause deadlock.
// Therefore, we hint XLA what the correct ordering of these clusters should
// be to avoid deadlocks.
std::vector<string> xla_token_input_nodes;
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
auto cluster_deps_it = cluster_deps.find(original_oc_name);
if (cluster_deps_it != cluster_deps.end()) {
for (auto dep : cluster_deps_it->second) {
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
}
}
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
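// Illustrative example (not part of the diff): if this cluster is "O2" and it
// depends on cluster "O1", the attribute set above becomes
//   _xla_token_input_nodes = ["_xla_token_arg_node",
//                             "outside_compilation_O1_host_compute"],
// matching the expectations in the updated EncapsulateSubgraphsTest cases
// elsewhere in this diff.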
// Populate inputs.
std::vector<DataType> input_dtypes;
@ -371,7 +394,8 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it.
bool has_edge = false;
for (auto e : call_node->in_edges()) {
@ -393,8 +417,9 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
}
// Build XlaHostCompute NodeDef.
TF_ASSIGN_OR_RETURN(NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core));
TF_ASSIGN_OR_RETURN(
NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
TF_ASSIGN_OR_RETURN(Node * host_compute_node,
ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
@ -1589,6 +1614,11 @@ Status ExtractOutsideCompilationForFunction(
// We cannot early return here, because we might have outside compilation in
// If/While function body.
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(
fbody->graph, outside_compilation_attr_name));
// Preprocess edges between different outside compilations. They will be
// restored in `ConstructHostGraph()`.
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
@ -1643,7 +1673,7 @@ Status ExtractOutsideCompilationForFunction(
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core));
graph_out.get(), n, host_compute_core, *cluster_deps));
}
// Handle nodes with associated functions.
@ -1691,11 +1721,13 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified) {
if (VLOG_IS_ON(4)) {
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}
*modified = false;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;
@ -1711,6 +1743,7 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
*modified |= has_outside_compilation;
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];

View File

@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified);
} // namespace tensorflow

View File

@ -922,4 +922,145 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
}
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterDataDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "identity0" (outside compilation cluster "1")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterControlDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
// control dependent on cluster "0")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(
s.WithOpName("identity1").WithControlDependencies(identity0), const0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}
} // namespace tensorflow

View File

@ -13,11 +13,23 @@ cc_library(
srcs = ["graphcycles.cc"],
hdrs = ["graphcycles.h"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "ordered_set",
hdrs = ["ordered_set.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)
@ -31,3 +43,14 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "ordered_set_test",
srcs = ["ordered_set_test.cc"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -38,13 +38,16 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
typedef std::unordered_set<int32> NodeSet;
using NodeSet = absl::flat_hash_set<int32>;
using OrderedNodeSet = OrderedSet<int32>;
template <typename T>
struct VecStruct {
typedef absl::InlinedVector<T, 4> type;
@ -53,13 +56,11 @@ template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
Node() : in(4), out(4) {} // Small hashtables for in/out edges
int32 rank; // rank number assigned by Pearce-Kelly algorithm
bool visited; // Temporary marker used by depth-first-search
void* data; // User-supplied data
NodeSet in; // List of immediate predecessor nodes in graph
NodeSet out; // List of immediate successor nodes in graph
OrderedNodeSet in; // List of immediate predecessor nodes in graph
OrderedNodeSet out; // List of immediate successor nodes in graph
};
} // namespace
@ -96,7 +97,7 @@ bool GraphCycles::CheckInvariants() const {
if (!ranks.insert(nx->rank).second) {
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
}
for (auto y : nx->out) {
for (int32 y : nx->out.GetSequence()) {
Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) {
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
@ -127,14 +128,14 @@ int32 GraphCycles::NewNode() {
void GraphCycles::RemoveNode(int32 node) {
Node* x = rep_->nodes_[node];
for (auto y : x->out) {
rep_->nodes_[y]->in.erase(node);
for (int32 y : x->out.GetSequence()) {
rep_->nodes_[y]->in.Erase(node);
}
for (auto y : x->in) {
rep_->nodes_[y]->out.erase(node);
for (int32 y : x->in.GetSequence()) {
rep_->nodes_[y]->out.Erase(node);
}
x->in.clear();
x->out.clear();
x->in.Clear();
x->out.Clear();
rep_->free_nodes_.push_back(node);
}
@ -147,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) {
}
bool GraphCycles::HasEdge(int32 x, int32 y) const {
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
return rep_->nodes_[x]->out.Contains(y);
}
void GraphCycles::RemoveEdge(int32 x, int32 y) {
rep_->nodes_[x]->out.erase(y);
rep_->nodes_[y]->in.erase(x);
rep_->nodes_[x]->out.Erase(y);
rep_->nodes_[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
@ -168,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes_[x];
if (!nx->out.insert(y).second) {
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}
Node* ny = r->nodes_[y];
ny->in.insert(x);
ny->in.Insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
@ -185,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (!ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.erase(y);
ny->in.erase(x);
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_);
@ -212,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
nn->visited = true;
r->deltaf_.push_back(n);
for (auto w : nn->out) {
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes_[w];
if (nw->rank == upper_bound) {
return false; // Cycle
@ -238,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
nn->visited = true;
r->deltab_.push_back(n);
for (auto w : nn->in) {
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes_[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack_.push_back(w);
@ -324,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
return path_len;
}
for (auto w : r->nodes_[n]->out) {
for (auto w : r->nodes_[n]->out.GetSequence()) {
if (seen.insert(w).second) {
r->stack_.push_back(w);
}
@ -378,31 +379,35 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
}
Node* nb = rep_->nodes_[b];
std::unordered_set<int32> out = std::move(nb->out);
std::unordered_set<int32> in = std::move(nb->in);
for (auto y : out) {
rep_->nodes_[y]->in.erase(b);
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32 y : out.GetSequence()) {
rep_->nodes_[y]->in.Erase(b);
}
for (auto y : in) {
rep_->nodes_[y]->out.erase(b);
for (int32 y : in.GetSequence()) {
rep_->nodes_[y]->out.Erase(b);
}
rep_->free_nodes_.push_back(b);
for (auto y : out) {
rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
for (int32 y : out.GetSequence()) {
InsertEdge(a, y);
}
for (auto y : in) {
rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
for (int32 y : in.GetSequence()) {
InsertEdge(y, a);
}
return true;
}
std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out;
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out.GetSequence();
}
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in;
absl::Span<const int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in.GetSequence();
}
namespace {
@ -444,7 +449,7 @@ string GraphCycles::DebugString() const {
continue;
}
for (int32 succ : rep_->nodes_[i]->out) {
for (int32 succ : rep_->nodes_[i]->out.GetSequence()) {
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
}
}

View File

@ -40,8 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
#include <unordered_set>
#include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -119,8 +118,8 @@ class GraphCycles {
// Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node) const;
std::unordered_set<int32> Predecessors(int32 node) const;
absl::Span<const int32> Successors(int32 node) const;
absl::Span<const int32> Predecessors(int32 node) const;
// Returns all nodes in post order.
//

View File

@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}
// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
DCHECK(it != value_to_index_.end());
// Since we don't want to move values around in `value_sequence_`, we swap
// the value in the last position with the value to be deleted and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}
void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}
void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}
bool Contains(T value) const { return value_to_index_.contains(value); }
size_t Size() const { return value_sequence_.size(); }
absl::Span<T const> GetSequence() const { return value_sequence_; }
private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;
// Maps values to their indices in `value_sequence_`.
absl::flat_hash_map<T, int> value_to_index_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
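
A brief usage sketch (not part of the diff) of the OrderedSet defined above, assuming the include path used by the test below. It shows the behavior the unit tests lock down: iteration order depends only on the sequence of inserts and erases, and Erase swaps the victim with the last element before popping.

#include <iostream>
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"

int main() {
  tensorflow::OrderedSet<int> set;
  set.Insert(90);
  set.Insert(100);
  set.Insert(80);
  set.Erase(100);  // 80 is swapped into 100's slot, then the tail is popped
  for (int v : set.GetSequence()) {
    std::cout << v << " ";  // prints "90 80"
  }
  std::cout << "\n";
  return 0;
}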

View File

@ -0,0 +1,117 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
TEST(OrderedSetTest, Insert) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
EXPECT_FALSE(ordered_set.Insert(100));
EXPECT_EQ(ordered_set.Size(), 3);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_TRUE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
EXPECT_FALSE(ordered_set.Contains(40));
std::array<int, 3> expected_sequence = {90, 100, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, Erase) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Erase(100);
EXPECT_EQ(ordered_set.Size(), 2);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
std::array<int, 2> expected_sequence_0 = {90, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0);
ordered_set.Erase(80);
EXPECT_EQ(ordered_set.Size(), 1);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 1> expected_sequence_1 = {90};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1);
ordered_set.Erase(90);
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence_2 = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2);
}
TEST(OrderedSetTest, Clear) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Clear();
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, LargeInsertions) {
const int kSize = 50 * 9000;
OrderedSet<int> ordered_set;
for (int i = 0; i < kSize; i++) {
EXPECT_TRUE(ordered_set.Insert(i + 500));
}
for (int i = 0; i < kSize; i++) {
EXPECT_EQ(ordered_set.GetSequence()[i], i + 500);
}
}
} // namespace
} // namespace tensorflow

View File

@ -62,7 +62,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<XlaAllocator> xla_allocator;
xla::DeviceMemoryAllocator* device_allocator = nullptr;
se::DeviceMemoryAllocator* device_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;

View File

@ -40,7 +40,7 @@ class XlaPlatformInfo {
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<XlaAllocator> xla_allocator,
xla::DeviceMemoryAllocator* device_allocator)
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
@ -55,7 +55,7 @@ class XlaPlatformInfo {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
xla::DeviceMemoryAllocator* allocator() const {
se::DeviceMemoryAllocator* allocator() const {
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
}
DeviceType device_type() const { return device_type_; }
@ -86,7 +86,7 @@ class XlaPlatformInfo {
// then device_allocator_ is null and xla_allocator_ points to an appropriate
// XlaAllocator instance.
std::unique_ptr<XlaAllocator> xla_allocator_;
xla::DeviceMemoryAllocator* device_allocator_;
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};

View File

@ -270,7 +270,7 @@ class MarkForCompilationPassImpl {
StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
const Cluster& to);
const Cluster& from, const Cluster& to);
// Returns true if the devices in `cluster_a` and `cluster_b` are compatible
// and therefore not a hindrance for combining the two clusters into a larger
@ -698,7 +698,7 @@ Status MarkForCompilationPassImpl::DumpDebugInfo() {
StatusOr<bool>
MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
const Cluster& cluster_to) {
const Cluster& cluster_from, const Cluster& cluster_to) {
// If any of the consumer's producers are on a different device, do not
// cluster these nodes. This prevents other work on this device from being
// delayed by work on other devices. We consider predecessors of the entire
@ -722,6 +722,11 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
if (!devices_compatible) {
return true;
}
TF_ASSIGN_OR_RETURN(devices_compatible,
AreDevicesCompatible(cluster_from, *cluster_in));
if (!devices_compatible) {
return true;
}
}
}
@ -1026,7 +1031,7 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
}
TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
ClusteringWillIntroduceInterDeviceDependency(*to));
ClusteringWillIntroduceInterDeviceDependency(*from, *to));
if (will_introduce_cross_device_dependency) {
return LogNotContractableAndReturnFalse(
@ -1062,8 +1067,16 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
Cluster* cluster_from) {
bool changed = false;
for (int to :
cycles_graph_.Successors(cluster_from->cycles_graph_node_id())) {
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32> successors_copy = [&] {
absl::Span<const int32> successors =
cycles_graph_.Successors(cluster_from->cycles_graph_node_id());
return std::vector<int32>(successors.begin(), successors.end());
}();
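// (Illustrative note, not in the original diff: Successors() now returns an
// absl::Span that aliases the OrderedSet backing the node's edge list, so
// mutating the cycles graph inside TryToContractEdge while iterating that
// span directly could leave it dangling; iterating the copy avoids that.)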
for (int to : successors_copy) {
iteration_count_++;
if (to >= graph_->num_node_ids()) {
// Node is a fictitious node that is present only in the cycle detection
@ -1265,19 +1278,15 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
DeviceSet devices = cluster_a.devices();
devices.UnionWith(cluster_b.devices());
// First check if we will even be able to pick a device for the larger
// combined cluster.
TF_ASSIGN_OR_RETURN(
bool can_pick_device,
CanPickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
if (!can_pick_device) {
absl::optional<jit::DeviceId> maybe_chosen_device,
MaybePickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
if (!maybe_chosen_device.has_value()) {
return false;
}
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
PickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
jit::DeviceId chosen_device = *maybe_chosen_device;
// If we are able to pick a device `chosen_device` for the larger cluster, the
// resource operations in `cluster_a` and `cluster_b` must be placed on the
@ -1415,7 +1424,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_svd_op = true;
op_filter.allow_slow_and_inaccurate_ops = true;
return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
.IsCompilableCall(ndef, flr);

View File

@ -1112,6 +1112,45 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}
TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
// This is similar to the 'DontClusterMergingNodes' above, except
// MatMulCombined is placed on the CPU.
Scope root = Scope::NewRootScope().ExitOnError();
absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
Output combined =
ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
TF_ASSERT_OK(root.ToGraph(graph.get()));
for (Node* n : graph->nodes()) {
if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
n->set_assigned_device_name(string(xla_cpu_dev0));
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
n->set_assigned_device_name(string(xla_gpu_dev0));
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
n->set_assigned_device_name(string(xla_gpu_dev1));
}
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
// Each of the MatMuls should be in a separate cluster.
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}
// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
// MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed

View File

@ -60,7 +60,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_svd_op = true;
registration.cluster_slow_and_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -95,7 +95,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_svd_op = true;
registration.cluster_slow_and_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -63,7 +63,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_svd_op = true;
registration.cluster_slow_and_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);

View File

@ -168,11 +168,11 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
}
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
: xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
@ -184,8 +184,8 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
"Out of memory while trying to allocate ", size, " bytes.");
}
}
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
}
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
@ -194,7 +194,7 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
}
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
@ -374,7 +374,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
@ -435,7 +435,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
*variable_infos[i].var()->tensor() = output_tensor;
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;

View File

@ -23,14 +23,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/owning_device_memory.h"
namespace tensorflow {
class XlaAllocator;
@ -108,11 +108,11 @@ Status LockVariables(absl::Span<VariableInfo> variables)
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
class XlaAllocator : public se::DeviceMemoryAllocator {
public:
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
xla::StatusOr<xla::OwningDeviceMemory> Allocate(
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override;
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
@ -142,7 +142,7 @@ class XlaComputationLaunchContext {
// because we track inter-stream dependencies through events inside XlaTensor
// objects.
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors,
bool use_multiple_streams);
@ -186,7 +186,7 @@ class XlaComputationLaunchContext {
private:
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;

View File

@ -59,7 +59,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype,
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
uint64 size =
client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
client->backend().memory_allocator()->Allocate(
device_ordinal, size, /*retry_on_failure=*/false));
// Move our buffer into shaped_buffer, which takes ownership of it.

View File

@ -458,10 +458,6 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
tags = [
"manual",
"notap",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",

View File

@ -57,7 +57,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {

View File

@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
float getDynamicRangeMax() const override { return 0.f; }
#endif
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
bool isShape() const override { return false; }
#endif
private:
nvinfer1::DataType trt_dtype_;
nvinfer1::Dims trt_dims_;
@ -4455,6 +4463,40 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
return Status::OK();
}
Status ConvertSquaredDifference(OpConverterParams* params) {
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
// Broadcast inputs.
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape(
inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l));
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r));
if (params->validation_only) return Status::OK();
// Subtract x - y.
nvinfer1::IElementWiseLayer* sub =
params->converter->network()->addElementWise(
*tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB);
TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name());
// Multiply (x - y) * (x - y).
nvinfer1::IElementWiseLayer* mul =
params->converter->network()->addElementWise(
*sub->getOutput(0), *sub->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name());
params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0)));
return Status::OK();
}
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
Status ConvertCombinedNMS(OpConverterParams* params) {
TF_RETURN_IF_ERROR(
@ -4641,7 +4683,6 @@ static void RegisterValidatableOpConverters(
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["GatherV2"] = ConvertGather;
(*registration)["Identity"] = ConvertIdentity; // Identity should be removed
(*registration)["LeakyRelu"] = ConvertLeakyRelu;
(*registration)["MatMul"] = ConvertMatMul;
(*registration)["Pack"] = ConvertPack;
@ -4650,11 +4691,11 @@ static void RegisterValidatableOpConverters(
(*registration)["Reshape"] = ConvertReshape;
(*registration)["Rsqrt"] = ConvertRsqrt;
(*registration)["Slice"] = ConvertSlice;
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
(*registration)["Softmax"] = ConvertSoftmax;
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
(*registration)["Split"] = ConvertSplit;
(*registration)["Square"] = ConvertSquare;
(*registration)["SquaredDifference"] = ConvertSquaredDifference;
(*registration)["Squeeze"] = ConvertSqueeze;
(*registration)["StridedSlice"] = ConvertStridedSlice;
(*registration)["TopKV2"] = ConvertTopK;
@ -4688,6 +4729,11 @@ static void RegisterValidatableOpConverters(
for (auto arg_minmax_type : {"ArgMin", "ArgMax"}) {
(*registration)[arg_minmax_type] = ConvertArgMinMax;
}
// The following are no-ops during inference and will not be mapped to any TRT
// layer.
for (auto identity_op_type : {"Identity", "Snapshot", "StopGradient"}) {
(*registration)[identity_op_type] = ConvertIdentity;
}
}
void TrtNodeValidator::RegisterOpValidators() {

View File

@ -50,8 +50,8 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
@ -280,6 +280,14 @@ class FakeITensor : public nvinfer1::ITensor {
float getDynamicRangeMax() const override { return 0.f; }
#endif
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
bool isShape() const override { return false; }
#endif
private:
string name_;
nvinfer1::Dims dims_;
@ -5353,6 +5361,108 @@ TEST_F(OpConverterTest, ConvertClipByValue) {
}
#endif // IS_TRT_VERSION_GE(5, 1, 2, 0)
// Get the NodeDef for SquaredDifference.
NodeDef GetSquaredDifferenceNodeDef(DataType dtype) {
Scope s = Scope::NewRootScope();
auto x = ops::Placeholder(s.WithOpName("x"), dtype);
auto y = ops::Placeholder(s.WithOpName("y"), dtype);
auto squared_diff =
ops::SquaredDifference(s.WithOpName("my_squared_diff"), x, y);
return squared_diff.operation.node()->def();
}
template <DataType dtype>
void TestConvertSquaredDifference(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
struct TestParams {
std::vector<int> dims_x;
std::vector<int> dims_y;
std::vector<CType> value_x;
std::vector<CType> value_y;
std::vector<int> expected_output_dims;
std::vector<CType> expected_output;
};
const std::vector<CType> common_input = InitTestVector<CType>(6);
std::vector<TestParams> params = {
{
/*dims_x=*/{1, 2, 3},
/*dims_y=*/{1, 2, 3},
/*value_x=*/common_input,
/*value_y=*/CastTestVector<int, CType>({0, -1, 3, 0, 10, -7}),
/*expected_output_dims=*/{1, 2, 3},
/*expected_output=*/CastTestVector<int, CType>({0, 4, 1, 9, 36, 144}),
},
{
/*dims_x=*/{1, 2, 3},
/*dims_y=*/{1, 1, 3},
/*value_x=*/common_input,
/*value_y=*/CastTestVector<int, CType>({0, 1, 2}),
/*expected_output_dims=*/{1, 2, 3},
/*expected_output=*/CastTestVector<int, CType>({0, 0, 0, 9, 9, 9}),
},
};
for (int i = 0; i < params.size(); ++i) {
test->Reset();
NodeDef node_def = GetSquaredDifferenceNodeDef(dtype);
test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype));
test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype));
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
output.tensor()->getDimensions());
DataVec input_data{{"x", test::AsTensor<CType>(params[i].value_x)},
{"y", test::AsTensor<CType>(params[i].value_y)}};
DataVec output_data{
{"my_squared_diff",
ConstructTensor<CType>(params[i].expected_output.size())}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(params[i].expected_output));
}
}
TEST_F(OpConverterTest, ConvertSquaredDifference) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_squared_diff", "SquaredDifference", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"SquaredDifference got 0 inputs but expected 2, at my_squared_diff");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
AddTestTensor("y", {1, 2, 3});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"x\" for SquaredDifference must be "
"a tensor, at my_squared_diff");
}
{
// Shapes are not broadcastable, should fail.
Reset();
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
AddTestTensor("x", {2, 3});
AddTestTensor("y", {7, 5});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Infeasible broadcast scheme");
}
TestConvertSquaredDifference<DT_FLOAT>(this);
TestConvertSquaredDifference<DT_HALF>(this);
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

View File

@ -41,7 +41,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
@ -259,7 +259,6 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
}
auto lib = ctx->function_library();
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.runner = ctx->runner();

View File

@ -32,7 +32,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {

View File

@ -20,8 +20,8 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {

View File

@ -19,7 +19,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -22,7 +22,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {

View File

@ -25,7 +25,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {

View File

@ -27,7 +27,10 @@ package(
default_visibility = [":internal"],
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load(
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
cc_library(

View File

@ -253,6 +253,7 @@ Status FunctionalizeControlFlowPass::Run(
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
bool fld_modified = false;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@ -273,9 +274,16 @@ Status FunctionalizeControlFlowPass::Run(
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
fld_modified = true;
}
}
if (fld_modified) {
TF_RETURN_IF_ERROR(
PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("functionalize_control_flow_after", *graph,
options.flib_def);

View File

@ -367,7 +367,7 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
@ -380,7 +380,7 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],

View File

@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@ -99,23 +101,22 @@ class ExtractImagePatchesOp : public XlaOpKernel {
// The following code is equivalent to:
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
int64 kernel_size = 1;
std::vector<int64> lhs_shape(num_dims, 1);
std::vector<int64> kernel_shape(num_dims, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
lhs_shape[i] = ksizes_[input_dim];
kernel_shape[i] = ksizes_[input_dim];
kernel_size *= ksizes_[input_dim];
}
lhs_shape[num_spatial_dims] = depth;
lhs_shape[num_spatial_dims + 1] = 1;
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth);
auto lhs = xla::Reshape(iota, lhs_shape);
auto filter = xla::ConvertElementType(
xla::Eq(lhs, iota, {num_spatial_dims + 1}), type);
kernel_shape[num_spatial_dims] = 1;
kernel_shape[num_spatial_dims + 1] = kernel_size * depth;
xla::Shape iota_kernel_shape =
xla::ShapeUtil::MakeShape(xla::S32, {kernel_size, depth, kernel_size});
xla::XlaOp filter =
xla::Reshape(xla::ConvertElementType(
xla::Eq(xla::Iota(builder, iota_kernel_shape, 0),
xla::Iota(builder, iota_kernel_shape, 2)),
type),
kernel_shape);
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(num_spatial_dims);
@ -148,7 +149,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
xla::XlaOp conv =
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims);
lhs_dilation, rhs_dilation, dims, depth);
ctx->SetOutput(0, conv);
}

View File

@ -16,7 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_1d_xla_impl(out, data);
}
REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);

View File

@ -16,7 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_2d_xla_impl(out, data);
}
REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA specific pooling ops.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -327,6 +328,20 @@ class MaxPoolGradOp : public XlaOpKernel {
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
// Create a MaxPool operation to check the expected resulting shape, and
// then throw away the operation because we don't actually need it here.
TensorShape expected_out_shape;
auto pooling =
xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
auto status_or_shape = pooling.builder()->GetShape(pooling);
OP_REQUIRES_OK(ctx, status_or_shape.status());
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
&expected_out_shape));
OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
errors::Unimplemented("The output dimensions do not match the "
"other input values."));
xla::PrimitiveType element_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));

View File

@ -733,6 +733,7 @@ Status RearrangeFunctionArgumentPass::Run(
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
bool fld_modified = false;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@ -753,8 +754,14 @@ Status RearrangeFunctionArgumentPass::Run(
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
fld_modified = true;
}
}
if (fld_modified) {
TF_RETURN_IF_ERROR(
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_after", *graph,

View File

@ -773,4 +773,17 @@ Status PropagateConstIntoFunctionalNodes(
return Status::OK();
}
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld) {
GraphDef graph_def;
g.ToGraphDef(&graph_def);
FunctionLibraryDefinition reachable_functions =
fld->ReachableDefinitions(graph_def);
for (const string& func_name : fld->ListFunctionNames()) {
if (!reachable_functions.Find(func_name)) {
TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@ -197,6 +198,10 @@ Status PropagateConstIntoFunctionalNodes(
Graph* g, const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld);
// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

View File

@ -58,18 +58,13 @@ class XlaCompilationAllocator : public Allocator {
// Make sure that even tensors with 0 elements have allocated
// buffers, so they get ids to track.
bool ShouldAllocateEmptyTensors() const override { return true; }
private:
// Don't run any constructors or destructors for complex objects,
// since there is no backing store for the tensor to run them
// on. strings are the only complex objects currently stored in
// Tensors. If others are added, this set of overrides must be
// extended to include them.
void RunStringCtor(string* p, size_t n) override {}
void RunStringDtor(string* p, size_t n) override {}
void RunResourceCtor(ResourceHandle* p, size_t n) override {}
void RunResourceDtor(ResourceHandle* p, size_t n) override {}
//
// NOTE: It is the caller's responsibility to track whether an allocated
// object is a buffer or an opaque handle. In particular, when this allocator
// is used, the caller must not run any constructors or destructors for
// complex objects, since there is no backing store for the tensor in which to
// place their outputs.
bool AllocatesOpaqueHandle() const override { return true; }
};
XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,

View File

@ -339,7 +339,7 @@ class XlaCompiler {
// here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
// allocate most or all available memory on the device, leaving none for the
// compiler to access, unless it can use TensorFlow's allocator.
xla::DeviceMemoryAllocator* device_allocator = nullptr;
se::DeviceMemoryAllocator* device_allocator = nullptr;
};
explicit XlaCompiler(Options options);

View File

@ -116,10 +116,9 @@ class XlaOpRegistry {
// If we should cluster operations returning DT_VARIANT.
bool cluster_variant_ops = false;
// If we should cluster the "Svd" op. The XLA implementation of this op has
// some performance (b/128001705) and possibly correctness (b/127344411)
// issues so we avoid auto-clustering it for non XLA_* devices.
bool cluster_svd_op = false;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool cluster_slow_and_inaccurate_ops = false;
};
// Registers an XLA backend. `compilation_device_name` is the name of the

View File

@ -96,7 +96,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
@ -117,7 +117,6 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo_proto",
@ -126,6 +125,7 @@ cc_library(
"//tensorflow/compiler/xla/service:source_map_util",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@llvm//:support",
@ -165,11 +165,11 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compile_only_service",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -39,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {

View File

@ -22,12 +22,12 @@ limitations under the License.
namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
DeviceMemoryAllocator* allocator) {
se::DeviceMemoryAllocator* allocator) {
device_allocator_ = allocator;
return *this;
}
DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
se::DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
return device_allocator_;
}

View File

@ -18,11 +18,11 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
@ -57,11 +57,11 @@ class ExecutableBuildOptions {
// want to run various algorithms on the device and pick the fastest one -- it
// might allocate buffers for use by these algorithms using this allocator.
//
// This does not need to be the same as the DeviceMemoryAllocator passed when
// running the executable.
// This does not need to be the same as the se::DeviceMemoryAllocator passed
// when running the executable.
ExecutableBuildOptions& set_device_allocator(
DeviceMemoryAllocator* allocator);
DeviceMemoryAllocator* device_allocator() const;
se::DeviceMemoryAllocator* allocator);
se::DeviceMemoryAllocator* device_allocator() const;
// Returns a string representation of the build options, suitable for
// debugging.
@ -77,7 +77,7 @@ class ExecutableBuildOptions {
Shape result_layout_;
bool result_layout_set_ = false;
absl::optional<DebugOptions> debug_options_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
int num_replicas_ = 1;
};

View File

@ -528,7 +528,9 @@ XlaOp Asin(XlaOp x) {
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
XlaOp Tan(XlaOp x) {
return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
}
// Hyperbolic trigonometric functions.
@ -574,9 +576,9 @@ XlaOp Acosh(XlaOp x) {
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
// as 2*x and return log(2) + log(x).
//
// If x is negative, the above would give us some trouble, because we'd need to
// approximate x + sqrt(sqrt(x^2 + 1) - abs(x). But we're saved
// by the fact that asinh(-x) = -asinh(x).
// If x is negative, the above would give us some trouble; we can't approximate
// the result as x + abs(x) = 0! But we're saved by the fact that asinh(-x) =
// -asinh(x).
XlaOp Asinh(XlaOp x) {
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
@ -636,9 +638,39 @@ XlaOp Atanh(XlaOp x) {
});
}
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
// Cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Cosh(XlaOp x) {
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
auto log_one_half = Log(ScalarLike(x, 0.5));
return Exp(x + log_one_half) + Exp(-x + log_one_half);
});
}
XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }
// Sinh(x) = (e^x - e^-x) / 2
// = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Sinh(XlaOp x) {
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
auto log_one_half = Log(ScalarLike(x, 0.5));
return Exp(x + log_one_half) - Exp(-x + log_one_half);
});
}
XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
XlaBuilder* builder = x.builder();

View File

@ -279,7 +279,7 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
const LiteralSlice& literal, int device_ordinal,
DeviceMemoryAllocator* allocator) {
se::DeviceMemoryAllocator* allocator) {
if (allocator == nullptr) {
allocator = backend().memory_allocator();
}

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
@ -32,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
@ -137,7 +137,7 @@ class LocalClient : public Client {
// device is used.
StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
const LiteralSlice& literal, int device_ordinal,
DeviceMemoryAllocator* allocator = nullptr);
se::DeviceMemoryAllocator* allocator = nullptr);
// Transfer the BorrowingLiteral to the device with the given ordinal.
StatusOr<TransferToServerResponse> TransferToLocalServer(

View File

@ -26,12 +26,13 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; }
ExecutableRunOptions& ExecutableRunOptions::set_allocator(
DeviceMemoryAllocator* allocator) {
stream_executor::DeviceMemoryAllocator* allocator) {
allocator_ = allocator;
return *this;
}
DeviceMemoryAllocator* ExecutableRunOptions::allocator() const {
stream_executor::DeviceMemoryAllocator* ExecutableRunOptions::allocator()
const {
return allocator_;
}

View File

@ -23,6 +23,7 @@ limitations under the License.
namespace stream_executor {
class Stream;
class Platform;
class DeviceMemoryAllocator;
} // namespace stream_executor
namespace Eigen {
@ -31,7 +32,6 @@ struct ThreadPoolDevice;
namespace xla {
class DeviceMemoryAllocator;
class DeviceAssignment;
class ExecutionProfile;
@ -39,8 +39,9 @@ class ExecutionProfile;
class ExecutableRunOptions {
public:
// Specifies the allocator to use during execution.
ExecutableRunOptions& set_allocator(DeviceMemoryAllocator* allocator);
DeviceMemoryAllocator* allocator() const;
ExecutableRunOptions& set_allocator(
stream_executor::DeviceMemoryAllocator* allocator);
stream_executor::DeviceMemoryAllocator* allocator() const;
// If set, this is the device to run the computation on. Valid device_ordinal
// values are: 0 to # of devices - 1. These values are identical to the device
@ -87,7 +88,7 @@ class ExecutableRunOptions {
int rng_seed() const;
private:
DeviceMemoryAllocator* allocator_ = nullptr;
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
const DeviceAssignment* device_assignment_ = nullptr;
stream_executor::Stream* stream_ = nullptr;

View File

@ -29,6 +29,8 @@ upper_tabs:
path: /xla/tiled_layout
- title: Using AOT compilation
path: /xla/tfcompile
- title: Writing custom calls
path: /xla/custom_call
- heading: Tutorials
- title: XLA compile API
path: /xla/tutorials/xla_compile

View File

@ -0,0 +1,329 @@
# XLA Custom Calls
This document describes how to write and use XLA "custom calls". Custom calls
let you invoke code written in a programming language like C++ or CUDA from an
XLA program.
Warning: Custom calls are a low-level power-user feature. It is easy to break
your program in difficult-to-debug (and even difficult-to-notice) ways using
custom-calls. You shouldn't use custom calls unless you're prepared to debug XLA
yourself when something goes wrong, and you should expect relatively less
assistance from XLA developers if you run into trouble.
Warning: The custom-call API/ABI is not currently stable. We don't intend to
change it capriciously, but it may change. Some possible future changes are
described below.
## Custom-call on CPU
You can create an HLO instruction which represents a custom-call via XLA's
client API. This is not exposed via TensorFlow as of writing.
For example, the following code uses a custom-call to compute
`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! -- do
this with regular HLO.)
```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}
void do_custom_call(void* out, const void** in) {
float* out_buf = reinterpret_cast<float*>(out);
const float* in0 = reinterpret_cast<const float*>(in[0]);
const float* in1 = reinterpret_cast<const float*>(in[1]);
for (int i = 0; i < 2048; ++i) {
out_buf[i] = in0[i % 128] + in1[i];
}
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```
Notice that the function `do_custom_call` needs to know the dimensions of the
buffers it operates over. In this example we hardcode the sizes 128 and 2048. If
you don't want to do this, you can pass the dimensions in as parameters to the
call.
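For instance, here is a minimal sketch of that approach, assuming the length is
passed as an extra scalar `S64` operand; the names `build_it` and
`do_custom_call_with_size` are hypothetical and only for illustration:
```c++
#include <cstdint>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void build_it(xla::XlaBuilder* b, xla::XlaOp param0, xla::XlaOp param1) {
  // Pass the output length as a scalar operand instead of hardcoding 2048.
  xla::XlaOp n = xla::ConstantR0<xla::int64>(b, 2048);
  xla::CustomCall(b, "do_custom_call_with_size",
                  /*operands=*/{param0, param1, n},
                  /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call_with_size(void* out, const void** in) {
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  // The third operand's buffer holds the scalar length.
  const int64_t n = *reinterpret_cast<const int64_t*>(in[2]);
  float* out_buf = reinterpret_cast<float*>(out);
  for (int64_t i = 0; i < n; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];  // The 128 stride stays hardcoded here.
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call_with_size, "Host");
```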
## Custom-call on GPU
The GPU custom call framework is somewhat different than that on the CPU. Here
is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as
the CPU code above.
```c++
void do_it() { /* same implementation as above */ }
__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, void** buffers,
const char* opaque, size_t opaque_len) {
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
float* out = reinterpret_cast<float*>(buffers[2]);
const int64 block_dim = 64;
const int64 grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```
Notice first that the GPU custom call function *is still a function executed on
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
on the GPU. Here it launches a CUDA kernel, but it could also do something else,
like call cublas.
`buffers` is an array of pointers which lives on the host, and each element it
contains points to device (i.e. GPU) memory. The parameters come first, followed
by the output value. This is notably different from the CPU calling convention,
which has two params, `ins` and `out`. The main reason we diverge is to make it
possible to handle tuple-shaped inputs/outputs efficiently; see the section
below.
As in the CPU example, we've hardcoded the input and output buffer sizes into
our custom call. However unlike in the CPU case, passing the buffer sizes in as
operands to the custom call would not work well. Usually we need the buffer
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
the block/grid dimensions to use. But if we were to pass the buffer sizes as
operands to our custom call, their values would live in GPU memory. We'd then
have to do an expensive synchronous device-to-host memcpy at the start of our
operation just to read the sizes.
To let you work around this, we provide the `opaque` parameter. You can set this
to an arbitrary string of bytes when you create the custom call:
```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
opaque);
```
Since `xla::Shape` has a protocol buffer representation, you could store this
serialized proto inside of `opaque` and deserialize it within your GPU
custom-call. Note however that although `xla::ShapeProto` does not change
frequently, it *does* change. Check the git log to see how it has changed in the
past.
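As a rough sketch of that pattern (error handling and includes elided; this
assumes the output shape is the information you want to pass, and the helper
name `build_call` is hypothetical):
```c++
// When building the computation: serialize the output shape into `opaque`.
void build_call(xla::XlaBuilder* b, xla::XlaOp param0, xla::XlaOp param1) {
  xla::Shape out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2048});
  std::string opaque = out_shape.ToProto().SerializeAsString();
  xla::CustomCall(b, "do_custom_call", /*operands=*/{param0, param1},
                  /*output_shape=*/out_shape, opaque);
}

// Inside the GPU custom-call implementation: recover the dimensions.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(opaque, static_cast<int>(opaque_len))) {
    return;  // In real code, surface this error somehow.
  }
  int64_t n = shape_proto.dimensions(0);
  // ... use `n` to choose block/grid dimensions and launch the kernel ...
}
```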
## Passing tuples to custom-calls
Consider the following custom-call.
```c++
using xla::ShapeUtil;
Shape p0_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTupleShape({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```
On both CPU and GPU, a tuple is represented in memory as an array of pointers.
In C++-pseudocode, parameter 0 above is laid out as follows.
```c++
// In-memory layout of parameter 0 from custom-call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];
void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;
void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```
Although the in-memory representation of tuples is the same in CPU and GPU, they
are handled differently in the CPU and GPU custom-call calling conventions.
### Tuple outputs as temp buffers
Tuple inputs to custom-calls are a convenience, but they aren't strictly
necessary. If we didn't support tuple inputs to custom calls, you could always
unpack the tuples using get-tuple-element before passing them to the custom
call.
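For example, reusing `p0` and `out_shape` from the example above, a sketch of
that unpacking looks like this:
```c++
// Flatten the tuple parameter with get-tuple-element and pass plain arrays.
xla::XlaOp buf0 = xla::GetTupleElement(p0, 0);
xla::XlaOp inner = xla::GetTupleElement(p0, 1);
xla::XlaOp buf1 = xla::GetTupleElement(inner, 0);
xla::XlaOp buf2 = xla::GetTupleElement(inner, 1);
xla::XlaOp buf3 = xla::GetTupleElement(p0, 2);
xla::CustomCall(&b, "do_custom_call",
                /*operands=*/{buf0, buf1, buf2, buf3}, out_shape);
```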
On the other hand, tuple *outputs* do let you do things you couldn't otherwise.
The obvious reason to have tuple outputs is that this is how a custom call (or
any other XLA op) returns multiple independent arrays.
But less obviously, a tuple output is also a way to give your custom call temp
memory. An *output* can represent a temp buffer: an output buffer has the
property that the op can write to it and can read from it after it's been
written to, which is exactly what you want from a temp buffer.
In the example above, suppose we wanted to use the `F32[1024]` as a temp buffer.
Then we'd write the HLO just as above, and we'd simply never read tuple index 1
of the custom call's output.
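In code, that just means ignoring one tuple element of the result (a sketch,
again reusing the names from the example above):
```c++
xla::XlaOp call =
    xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
// Keep only the F32[512] element; the F32[1024] element is scratch space that
// the custom call may write to and read from but that we never consume.
xla::XlaOp result = xla::GetTupleElement(call, 0);
```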
### Tuples in CPU custom-calls
In CPU code, we have a function `do_custom_call(void* out, const void** ins)`.
`ins` is an array with just one element, which points to `param0`. The
subbuffers of `param0` are accessible by dereferencing that pointer, and the
subbuffers of `output_tuple` are accessible by dereferencing `out`.
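Following that convention, a sketch of the CPU callback for the tuple example
above might walk the index tables like this (sizes and error handling omitted;
this is illustrative, not a drop-in implementation):
```c++
void do_custom_call(void* out, const void** ins) {
  // ins[0] points at param0's index table: {subbuf0, subtuple, subbuf3}.
  const void* const* param0 = reinterpret_cast<const void* const*>(ins[0]);
  const float* subbuf0 = reinterpret_cast<const float*>(param0[0]);
  const void* const* subtuple =
      reinterpret_cast<const void* const*>(param0[1]);
  const float* subbuf1 = reinterpret_cast<const float*>(subtuple[0]);
  const float* subbuf2 = reinterpret_cast<const float*>(subtuple[1]);
  const float* subbuf3 = reinterpret_cast<const float*>(param0[2]);
  // `out` points at the output tuple's index table: {out0, out1}.
  void* const* output_tuple = reinterpret_cast<void* const*>(out);
  float* out0 = reinterpret_cast<float*>(output_tuple[0]);
  float* out1 = reinterpret_cast<float*>(output_tuple[1]);
  // ... fill out0 (F32[512]) and out1 (F32[1024]) from the inputs ...
}
```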
### Tuples in GPU custom-calls
In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
this case `buffers` is a host array of *nine* device pointers, one for each
nested buffer. To generate the flat list, we iterate over the parameters and
output, and then do preorder traversal of their shapes. Concretely:
```c++
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == param0
buffers[1] == subbuf0 or null
buffers[2] == subtuple or null
buffers[3] == subbuf1 or null
buffers[4] == subbuf2 or null
buffers[5] == subbuf3 or null
buffers[6] == output_tuple
buffers[7] == output_subbuf0
buffers[8] == output_subbuf1
```
The `or null` part is significant. A sub-buffer of an input tuple will be
non-null in the `buffers` list if XLA is able to statically analyze the program
and figure out the address of the sub-buffer. This is usually the case, but may
not be in programs with control flow and/or `select` ops over tuples.
A correct custom-call implementation that accepts a tuple as input must always
handle null input sub-buffers, by dereferencing the root tuple.
The rule is reversed for output buffers. The output sub-buffers will always be
populated, but it's up to the custom call to populate the root tuple at the end.
See the following code. Note that we leave out CUDA error handling for clarity,
but you'll be thankful if you do it, because otherwise it can be hard to tell
when a stream encounters an error.
```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len) {
bool needs_sync = false;
const float* subbuf0 = reinterpret_cast<const float*>(buffers[1]);
if (subbuf0 == nullptr) {
needs_sync = true;
cudaMemcpyAsync(&subbuf0, buffers[0], sizeof(void*),
cudaMemcpyDeviceToHost, stream);
}
const void** subtuple = reinterpret_cast<const void**>(buffers[2]);
if (subtuple == nullptr) {
needs_sync = true;
    cudaMemcpyAsync(&subtuple,
                    reinterpret_cast<const char*>(buffers[0]) + sizeof(void*),
                    ...);
}
// ... similarly for other params ...
// Wait for copies enqueued above to complete.
if (needs_sync) {
cudaStreamSynchronize(stream);
}
needs_sync = false;
// Now that we have `subtuple`, we can get subbuf1 and subbuf2.
  float* subbuf1 = reinterpret_cast<float*>(buffers[3]);
if (subbuf1 == nullptr) {
needs_sync = true;
cudaMemcpyAsync(&subbuf1, subtuple, ...);
}
  float* subbuf2 = reinterpret_cast<float*>(buffers[4]);
if (subbuf2 == nullptr) {
needs_sync = true;
cudaMemcpyAsync(&subbuf2, subtuple + 1, ...);
}
// Wait for copies enqueued above to complete.
if (needs_sync) {
cudaStreamSynchronize(stream);
}
// ... actually run the kernel ...
// Fill the output tuple.
void* outputs[2] = {buffers[7], buffers[8]};
cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice,
stream);
// Necessary to force the cudaMemcpyAsync above to complete before `outputs`
// goes out of scope. A sync is only necessary in the tuple output case, and
// see below for a way to avoid this.
cudaStreamSynchronize(stream);
}
```
The `cudaStreamSynchronize` at the end of the function is unfortunate, as it's
not required in the non-tuple-output case, and it can be expensive. One way to
get around this would be to make `outputs` into a global variable and ensure
that the previous cudaMemcpyAsync completed before overwriting the global and
enqueueing another one. This is sketched below.
```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len) {
// ... Beginning of function is the same as above ...
// ... actually run the kernel ...
static std::atomic<bool> first_time{true};
static CUevent event;
static void* outputs[2];
  if (first_time.exchange(false)) {
// First time running this function. Initialize `event`.
cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
} else {
// Not first time running this function. Wait for previous event to
// complete before touching `outputs`.
cuEventSynchronize(event);
}
// Fill the output tuple.
outputs[0] = buffers[7];
outputs[1] = buffers[8];
cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice,
stream);
// Unblock `event` after the memcpy completes.
cuEventRecord(event, stream);
}
```
This simple implementation would limit parallelism if you want to run this op on
multiple GPUs concurrently (or on one GPU with multiple streams); in that case
you might need multiple events and globals. We have seen one implementation of
this algorithm which keeps a pool of globals and events and periodically polls
them (perhaps on each call to the op) to garbage collect.

View File

@ -67,8 +67,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@pybind11",
@ -109,9 +109,9 @@ cc_library(
hdrs = ["shared_device_buffer.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:flat_hash_set",
],
)
@ -131,11 +131,50 @@ tf_cc_test(
],
)
cc_library(
name = "local_client",
srcs = ["local_client.cc"],
hdrs = ["local_client.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
"-Wno-c++98-c++11-compat",
],
features = ["-use_header_modules"],
deps = [
":shared_device_buffer",
":types",
":worker_thread",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@pybind11",
],
)
tf_pybind_extension(
name = "xla_extension",
srcs = [
"local_client.cc",
"local_client.h",
"xla.cc",
],
copts = [
@ -146,22 +185,19 @@ tf_pybind_extension(
features = ["-use_header_modules"],
module_name = "xla_extension",
deps = [
":local_client",
":shared_device_buffer",
":types",
":worker_thread",
":xrt",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",
"//third_party/python_runtime:headers", # buildcleaner: keep
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
@ -178,7 +214,7 @@ tf_pybind_extension(
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:name_uniquer",
@ -186,9 +222,7 @@ tf_pybind_extension(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
# Do NOT remove this dependency. The XLA Python extension must not
# depend on any part of TensorFlow at runtime, **including**
# libtensorflow_framework.so. The XLA module is deployed self-contained

View File

@ -78,7 +78,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
}
cpu::CustomCallTargetRegistry::Global()->Register(
std::string(fn_name.begin(), fn_name.end()), static_cast<void*>(capsule));
CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host");
return Status::OK();
}
@ -147,7 +147,12 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
"py_xla_execute");
}
Device::~Device() { compute_stream_->parent()->SynchronizeAllActivity(); }
Device::~Device() {
bool ok = compute_stream_->parent()->SynchronizeAllActivity();
if (!ok) {
LOG(ERROR) << "SynchronizeAllActivity failed when destroying Device.";
}
}
void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
std::function<void()> callback) const {
@ -155,7 +160,7 @@ void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
[this, callback]() { worker_thread_->Schedule(std::move(callback)); });
}
StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
const std::string& platform_name, const std::string& xla_platform_name,
bool asynchronous) {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@ -168,7 +173,7 @@ StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
options.set_platform(platform);
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
return absl::make_unique<PyLocalClient>(platform_name, client, asynchronous);
return std::make_shared<PyLocalClient>(platform_name, client, asynchronous);
}
PyLocalClient::PyLocalClient(std::string platform_name, LocalClient* client,
@ -210,9 +215,9 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
}
static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
const Device& device) {
DeviceMemoryAllocator* allocator =
const PythonBufferTree& tree, int device_ordinal,
std::shared_ptr<PyLocalClient> client, const Device& device) {
se::DeviceMemoryAllocator* allocator =
client->client()->backend().memory_allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
@ -255,13 +260,13 @@ static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
device_buffer);
}
return PyLocalBuffer(shape, std::move(device_buffer), client);
return PyLocalBuffer(shape, std::move(device_buffer), std::move(client));
}
/* static */
StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
PyLocalClient* client,
int device_ordinal) {
StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(
const py::object& argument, std::shared_ptr<PyLocalClient> client,
int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
@ -277,13 +282,13 @@ StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
<< " device ordinal: " << device_ordinal;
const Device& device = client->device(device_ordinal);
TF_ASSIGN_OR_RETURN(
PyLocalBuffer buffer,
TransferHostToDeviceAsync(tree, device_ordinal, client, device));
TF_ASSIGN_OR_RETURN(PyLocalBuffer buffer,
TransferHostToDeviceAsync(tree, device_ordinal,
std::move(client), device));
device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref));
if (!device.asynchronous()) {
device.host_to_device_stream()->BlockHostUntilDone();
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
}
return buffer;
}
@ -291,7 +296,7 @@ StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
/*static */ StatusOr<std::vector<PyLocalBuffer>>
PyLocalBuffer::FromPythonValues(
const std::vector<std::pair<py::object, int>>& arguments,
PyLocalClient* client) {
std::shared_ptr<PyLocalClient> client) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues");
int num_arguments = static_cast<int>(arguments.size());
std::vector<PyLocalBuffer> outputs(num_arguments);
@ -344,7 +349,7 @@ PyLocalBuffer::FromPythonValues(
device.ThenRelease(device.host_to_device_stream(),
std::move(transfers[i].py_buffer_ref));
if (!device.asynchronous()) {
device.host_to_device_stream()->BlockHostUntilDone();
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
}
}
@ -355,8 +360,8 @@ PyLocalBuffer::FromPythonValues(
}
/* static */ StatusOr<PyLocalBuffer> PyLocalBuffer::MakeTuple(
const std::vector<PyLocalBuffer> buffers, PyLocalClient* client,
int device_ordinal) {
const std::vector<PyLocalBuffer> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
std::vector<xla::Shape> host_shapes;
std::vector<std::shared_ptr<PySharedDeviceBuffer>> device_buffers;
host_shapes.reserve(buffers.size());
@ -367,7 +372,7 @@ PyLocalBuffer::FromPythonValues(
host_shapes.push_back(buffer.on_host_shape());
device_buffers.push_back(buffer.device_buffer());
}
DeviceMemoryAllocator* allocator =
se::DeviceMemoryAllocator* allocator =
client->client()->backend().memory_allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
@ -382,7 +387,7 @@ PyLocalBuffer::FromPythonValues(
device_buffers, transfer_manager, allocator,
device_ordinal, definition_event));
PyLocalBuffer buffer(ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer,
client);
std::move(client));
// TODO(phawkins): extend TransferManager so we do not need to form a full
// ShapedBuffer just to write the root tuple index table.
@ -393,8 +398,8 @@ PyLocalBuffer::FromPythonValues(
// Wait for the compute stream so that memory allocations are synchronized.
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
}
transfer_manager->WriteRootTupleIndexTable(device.host_to_device_stream(),
shaped_buffer);
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
device.host_to_device_stream(), shaped_buffer));
if (definition_event) {
definition_event->RecordOnStream(device.host_to_device_stream());
}
@ -404,7 +409,7 @@ PyLocalBuffer::FromPythonValues(
std::move(tuple_buffer));
}
if (!device.asynchronous()) {
device.host_to_device_stream()->BlockHostUntilDone();
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
}
return buffer;
@ -412,10 +417,10 @@ PyLocalBuffer::FromPythonValues(
PyLocalBuffer::PyLocalBuffer(
Shape on_host_shape, std::shared_ptr<PySharedDeviceBuffer> device_buffer,
PyLocalClient* client)
: on_host_shape_(std::move(on_host_shape)),
device_buffer_(std::move(device_buffer)),
client_(client) {}
std::shared_ptr<PyLocalClient> client)
: client_(std::move(client)),
on_host_shape_(std::move(on_host_shape)),
device_buffer_(std::move(device_buffer)) {}
StatusOr<py::object> PyLocalBuffer::ToPython() const {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
@ -462,10 +467,10 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalBuffer::DestructureTuple() {
PyLocalExecutable::PyLocalExecutable(
std::shared_ptr<LocalExecutable> executable,
DeviceAssignment device_assignment, PyLocalClient* client)
: executable_(std::move(executable)),
device_assignment_(std::move(device_assignment)),
client_(client) {}
DeviceAssignment device_assignment, std::shared_ptr<PyLocalClient> client)
: client_(std::move(client)),
executable_(std::move(executable)),
device_assignment_(std::move(device_assignment)) {}
std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
int num_replicas = device_assignment_.replica_count();
@ -543,7 +548,7 @@ StatusOr<PyLocalBuffer> PyLocalExecutable::ExecuteHelper(
device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_);
}
if (!device.asynchronous()) {
device.compute_stream()->BlockHostUntilDone();
TF_RETURN_IF_ERROR(device.compute_stream()->BlockHostUntilDone());
}
return PyLocalBuffer(on_host_shape, std::move(out_buffer), client_);
}
@ -652,7 +657,7 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
PyLocalExecutable::Compile(const XlaComputation& computation,
std::vector<Shape> argument_layouts,
const ExecutableBuildOptions* build_options,
PyLocalClient* client) {
std::shared_ptr<PyLocalClient> client) {
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
std::vector<const Shape*> argument_layout_pointers;
argument_layout_pointers.reserve(argument_layouts.size());
@ -705,7 +710,7 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
return absl::make_unique<PyLocalExecutable>(
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
std::move(device_assignment), client);
std::move(device_assignment), std::move(client));
}
} // namespace xla

View File

@ -169,12 +169,13 @@ class PyLocalClient {
public:
// Initializes a local XLA client for `platform_name`. Returns an error if no
// such platform exists, or if the platform has no visible devices.
static StatusOr<std::unique_ptr<PyLocalClient>> Get(
const std::string& platform_name, const std::string& xla_platform_id,
static StatusOr<std::shared_ptr<PyLocalClient>> Get(
const std::string& platform_name, const std::string& xla_platform_name,
bool asynchronous);
explicit PyLocalClient(std::string platform_name, LocalClient* client,
bool asynchronous);
virtual ~PyLocalClient() = default;
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
@ -192,7 +193,7 @@ class PyLocalClient {
PythonRefManager& py_ref_manager() { return py_ref_manager_; }
private:
protected:
std::string platform_name_;
LocalClient* client_;
std::vector<std::unique_ptr<Device>> devices_;
@ -205,29 +206,30 @@ class PyLocalClient {
// Holds a reference from Python to one or more device buffers.
class PyLocalBuffer {
public:
static StatusOr<PyLocalBuffer> FromPython(const pybind11::object& argument,
PyLocalClient* client,
int device_ordinal);
static StatusOr<PyLocalBuffer> FromPython(
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
int device_ordinal);
// Converts multiple (python object, device ordinal) pairs into
// PyLocalBuffers in parallel.
static StatusOr<std::vector<PyLocalBuffer>> FromPythonValues(
const std::vector<std::pair<pybind11::object, int>>& argument,
PyLocalClient* client);
std::shared_ptr<PyLocalClient> client);
static StatusOr<PyLocalBuffer> MakeTuple(
const std::vector<PyLocalBuffer> buffers, PyLocalClient* client,
int device_ordinal);
const std::vector<PyLocalBuffer> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal);
PyLocalBuffer() = default;
PyLocalBuffer(Shape on_host_shape,
std::shared_ptr<PySharedDeviceBuffer> device_buffer,
PyLocalClient* client);
std::shared_ptr<PyLocalClient> client);
StatusOr<pybind11::object> ToPython() const;
const Shape& on_host_shape() const { return on_host_shape_; }
const std::shared_ptr<PySharedDeviceBuffer>& device_buffer() const {
return device_buffer_;
}
int device_ordinal() const { return device_buffer_->device_ordinal(); }
void Delete() {
device_buffer_ = nullptr;
@ -242,9 +244,9 @@ class PyLocalBuffer {
StatusOr<std::vector<PyLocalBuffer>> DestructureTuple();
private:
std::shared_ptr<PyLocalClient> client_ = nullptr;
Shape on_host_shape_;
std::shared_ptr<PySharedDeviceBuffer> device_buffer_;
PyLocalClient* client_ = nullptr;
};
// Represents a compiled computation that can be executed given handles to
@ -254,10 +256,12 @@ class PyLocalExecutable {
// Compiles a computation to an executable.
static StatusOr<std::unique_ptr<PyLocalExecutable>> Compile(
const XlaComputation& computation, std::vector<Shape> argument_layouts,
const ExecutableBuildOptions* build_options, PyLocalClient* client);
const ExecutableBuildOptions* build_options,
std::shared_ptr<PyLocalClient> client);
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
DeviceAssignment device_assignment, PyLocalClient* client);
DeviceAssignment device_assignment,
std::shared_ptr<PyLocalClient> client);
int num_replicas() const {
return executable_->build_options().num_replicas();
@ -285,9 +289,9 @@ class PyLocalExecutable {
StatusOr<PyLocalBuffer> ExecuteHelper(
absl::Span<PyLocalBuffer* const> argument_handles, int replica);
std::shared_ptr<PyLocalClient> const client_;
std::shared_ptr<LocalExecutable> executable_;
const DeviceAssignment device_assignment_;
PyLocalClient* const client_;
};
} // namespace xla

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
@ -47,14 +47,14 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) {
static std::shared_ptr<PySharedDeviceBuffer>
BufferFromScopedShapedBufferIterator(
const Shape& on_device_shape, int device_ordinal,
DeviceMemoryAllocator* allocator,
se::DeviceMemoryAllocator* allocator,
ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
const ShapeTree<se::DeviceMemoryBase>::iterator& end,
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
CHECK(*iterator != end);
OwningDeviceMemory device_memory((*iterator)->second, device_ordinal,
allocator);
se::OwningDeviceMemory device_memory((*iterator)->second, device_ordinal,
allocator);
(*iterator)->second = se::DeviceMemoryBase();
++*iterator;
@ -90,7 +90,7 @@ PySharedDeviceBuffer::FromScopedShapedBuffer(
/* static */ StatusOr<std::shared_ptr<PySharedDeviceBuffer>>
PySharedDeviceBuffer::MakeTuple(
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator,
int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event) {
std::vector<Shape> child_shapes;
@ -102,7 +102,7 @@ PySharedDeviceBuffer::MakeTuple(
Shape shape = ShapeUtil::MakeTupleShape(child_shapes);
TF_ASSIGN_OR_RETURN(
OwningDeviceMemory device_memory,
se::OwningDeviceMemory device_memory,
allocator->Allocate(device_ordinal,
transfer_manager->GetByteSizeRequirement(shape)));
return std::make_shared<PySharedDeviceBuffer>(
@ -113,10 +113,10 @@ PySharedDeviceBuffer::MakeTuple(
/* static */ StatusOr<std::shared_ptr<PySharedDeviceBuffer>>
PySharedDeviceBuffer::MakeArray(
Shape on_device_shape, TransferManager* transfer_manager,
DeviceMemoryAllocator* allocator, int device_ordinal,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event) {
TF_ASSIGN_OR_RETURN(
OwningDeviceMemory device_memory,
se::OwningDeviceMemory device_memory,
allocator->Allocate(
device_ordinal,
transfer_manager->GetByteSizeRequirement(on_device_shape)));
@ -153,7 +153,7 @@ ShapedBuffer PySharedDeviceBuffer::AsShapedBuffer(
}
PySharedDeviceBuffer::PySharedDeviceBuffer(
Shape on_device_shape, OwningDeviceMemory device_memory,
Shape on_device_shape, se::OwningDeviceMemory device_memory,
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event)
: on_device_shape_(std::move(on_device_shape)),

View File

@ -17,11 +17,11 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/owning_device_memory.h"
namespace xla {
@ -93,14 +93,14 @@ class PySharedDeviceBuffer {
// Makes a tuple buffer. Does not initialize the tuple table.
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeTuple(
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator,
int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event);
// Makes an uninitialized array buffer.
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeArray(
Shape on_device_shape, TransferManager* transfer_manager,
DeviceMemoryAllocator* allocator, int device_ordinal,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event);
// Builds a ShapedBuffer view onto the buffers of 'tree'. Since
@ -113,7 +113,7 @@ class PySharedDeviceBuffer {
const std::vector<std::shared_ptr<PySharedDeviceBuffer>>& children() const {
return children_;
}
const OwningDeviceMemory& device_memory() const { return device_memory_; }
const se::OwningDeviceMemory& device_memory() const { return device_memory_; }
int device_ordinal() const { return device_memory_.device_ordinal(); }
const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
return definition_event_;
@ -121,7 +121,7 @@ class PySharedDeviceBuffer {
PySharedDeviceBuffer() = default;
PySharedDeviceBuffer(
Shape on_device_shape, OwningDeviceMemory device_memory,
Shape on_device_shape, se::OwningDeviceMemory device_memory,
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event);
@ -130,7 +130,7 @@ class PySharedDeviceBuffer {
// one-to-one with the tree of device buffers, so to avoid representational
// awkwardness we maintain on-host shapes separately.
Shape on_device_shape_;
OwningDeviceMemory device_memory_;
se::OwningDeviceMemory device_memory_;
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children_;
// An event that is triggered when the content of one or more buffers is

View File

@ -16,15 +16,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/stream_executor/owning_device_memory.h"
namespace xla {
namespace py = pybind11;
xla::StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
const py::dtype& np_type) {
xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
{{'b', 1}, PRED},
@ -50,6 +49,42 @@ xla::StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
return it->second;
}
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
switch (type) {
case PRED:
return py::dtype::of<bool>();
case S8:
return py::dtype::of<int8>();
case S16:
return py::dtype::of<int16>();
case S32:
return py::dtype::of<int32>();
case S64:
return py::dtype::of<int64>();
case U8:
return py::dtype::of<uint8>();
case U16:
return py::dtype::of<uint16>();
case U32:
return py::dtype::of<uint32>();
case U64:
return py::dtype::of<uint64>();
case F16:
return py::dtype("e");
case F32:
return py::dtype::of<float>();
case F64:
return py::dtype::of<double>();
case C64:
return py::dtype::of<std::complex<float>>();
case C128:
return py::dtype::of<std::complex<double>>();
default:
return Unimplemented("Unimplemented primitive type %s",
PrimitiveType_Name(type));
}
}
// Returns a numpy-style format descriptor string for `type`.
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
switch (type) {
@ -159,4 +194,20 @@ StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
return tree;
}
py::tuple IntSpanToTuple(absl::Span<int64 const> xs) {
py::tuple out(xs.size());
for (int i = 0; i < xs.size(); ++i) {
out[i] = py::int_(xs[i]);
}
return out;
}
std::vector<int64> IntSequenceToVector(const py::object& sequence) {
std::vector<int64> output;
for (auto item : sequence) {
output.push_back(item.cast<int64>());
}
return output;
}
} // namespace xla

View File

@ -32,29 +32,48 @@ limitations under the License.
namespace xla {
// Converts a pybind11-style NumPy dtype to a PrimitiveType.
StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
const pybind11::dtype& np_type);
// Helper that converts a failing StatusOr to an exception.
// For use only inside pybind11 code.
template <typename T>
T ValueOrThrow(StatusOr<T> v) {
if (!v.ok()) {
throw std::runtime_error(v.status().ToString());
}
return v.ConsumeValueOrDie();
}
// Converts a NumPy dtype to a PrimitiveType.
StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);
// Converts a PrimitiveType to a Numpy dtype.
StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
// Converts a literal to (possibly-nested tuples of) NumPy arrays.
// The literal's leaf arrays are not copied; instead the NumPy arrays share
// buffers with the literals. Takes ownership of `literal` and keeps the
// necessary pieces alive using Python reference counting.
// Requires the GIL.
StatusOr<pybind11::object> LiteralToPython(
std::unique_ptr<xla::Literal> literal);
StatusOr<pybind11::object> LiteralToPython(std::unique_ptr<Literal> literal);
// Converts a Python object into an XLA shape and a vector of leaf buffers.
// The leaf buffers correspond to a depth-first, left-to-right traversal of
// the Python value.
// Requires the GIL.
struct PythonBufferTree {
absl::InlinedVector<xla::BorrowingLiteral, 1> leaves;
xla::Shape shape;
absl::InlinedVector<BorrowingLiteral, 1> leaves;
Shape shape;
};
StatusOr<PythonBufferTree> GetPythonBufferTree(
const pybind11::object& argument);
// Converts a sequence of int64s to a Python tuple of ints.
// Pybind11 by default converts a std::vector<int64> to a Python list; for
// shapes we frequently want a tuple instead.
pybind11::tuple IntSpanToTuple(absl::Span<int64 const> xs);
// Converts a Python sequence of integers to a std::vector<int64>
std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
} // namespace xla
// This namespace is a documented pybind11 extension point.
@ -161,7 +180,7 @@ struct type_caster<xla::BorrowingLiteral> {
for (int i = 0; i < array.ndim(); ++i) {
dims[i] = array.shape(i);
}
auto type = xla::NumpyTypeToPrimitiveType(array.dtype());
auto type = xla::DtypeToPrimitiveType(array.dtype());
if (!type.ok()) {
throw std::runtime_error(type.status().ToString());
}
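The DtypeToPrimitiveType / PrimitiveTypeToDtype pair above gives the bindings a two-way mapping between NumPy dtypes and XLA element types, which is what lets the reworked Shape bindings hand back dtypes instead of raw enum values. A minimal sketch of the round trip as seen from Python, assuming the xla_client module built from this change; the shape used is illustrative:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

# dtype -> XLA element type, via the DTYPE_TO_XLA_ELEMENT_TYPE table in xla_client.
f32 = xla_client.dtype_to_etype(np.float32)

# element type -> dtype, via the reworked Shape bindings: element_type() now
# returns a NumPy dtype, while the raw enum value moves to xla_element_type().
s = xla_client.Shape.array_shape(np.dtype(np.float32), (2, 3))
assert s.element_type() == np.dtype(np.float32)
assert s.xla_element_type() == f32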

View File

@ -16,10 +16,11 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/hash/hash.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/xrt.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -129,34 +129,95 @@ PYBIND11_MODULE(xla_extension, m) {
.value("TOKEN", TOKEN);
// Shapes
py::class_<Shape>(m, "Shape")
py::class_<Shape> shape_class(m, "Shape");
shape_class
.def_static(
"Tuple",
"tuple_shape",
[](std::vector<Shape> shapes) -> Shape {
return ShapeUtil::MakeTupleShape(shapes);
},
"Makes a tuple shape.")
"Constructs a tuple shape.")
.def_static(
"Array",
[](PrimitiveType type, std::vector<int64> dims,
absl::optional<std::vector<int64>> layout) -> Shape {
if (layout) {
return ShapeUtil::MakeShapeWithLayout(type, dims, *layout);
"array_shape",
[](PrimitiveType type, py::object dims_seq,
absl::optional<py::object> layout_seq) -> Shape {
std::vector<int64> dims = IntSequenceToVector(dims_seq);
if (layout_seq) {
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
} else {
Shape shape = ShapeUtil::MakeShape(type, dims);
shape.clear_layout();
return shape;
}
},
"Makes an array shape.", py::arg("type"), py::arg("dims"),
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
py::arg("layout") = absl::nullopt)
.def_static(
"array_shape",
[](py::dtype dtype, py::object dims_seq,
absl::optional<py::object> layout_seq) -> Shape {
PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
std::vector<int64> dims = IntSequenceToVector(dims_seq);
if (layout_seq) {
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
} else {
Shape shape = ShapeUtil::MakeShape(type, dims);
shape.clear_layout();
return shape;
}
},
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
py::arg("layout") = absl::nullopt)
.def("dimensions",
static_cast<const std::vector<int64>& (Shape::*)() const>(
&Shape::dimensions))
.def("element_type", &Shape::element_type)
[](const Shape& shape) -> py::tuple {
return IntSpanToTuple(shape.dimensions());
})
.def("xla_element_type", &Shape::element_type)
.def("element_type",
[](const Shape& shape) {
return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
})
.def("numpy_dtype",
[](const Shape& shape) {
if (shape.IsTuple()) {
return py::dtype("O");
}
return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
})
.def("is_tuple", &Shape::IsTuple)
.def("is_array", &Shape::IsArray)
.def("rank", &Shape::rank)
.def("to_serialized_proto",
[](const Shape& shape) {
ShapeProto proto = shape.ToProto();
return py::bytes(proto.SerializeAsString());
})
.def("tuple_shapes",
static_cast<const std::vector<Shape>& (Shape::*)() const>(
&Shape::tuple_shapes))
[](const Shape& shape) {
return std::vector<Shape>(shape.tuple_shapes());
})
.def(
"with_major_to_minor_layout_if_absent",
[](const Shape& shape) {
Shape out = shape;
ShapeUtil::ForEachMutableSubshape(
&out, [](Shape* subshape, const ShapeIndex&) {
if (!subshape->has_layout()) {
LayoutUtil::SetToDefaultLayout(subshape);
}
});
return out;
},
"Returns a copy of a shape with missing layouts set to "
"major-to-minor.")
.def("__eq__", [](const Shape& shape,
const Shape& other) { return shape == other; })
.def("__ne__", [](const Shape& shape,
const Shape& other) { return shape != other; })
.def("__hash__",
[](const Shape& shape) { return absl::Hash<Shape>()(shape); })
.def("__repr__", [](const Shape& shape) {
return shape.ToString(/*print_layouts=*/true);
});
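Taken together, these bindings replace the old Tuple/Array constructors with snake_case static methods and make Shape hashable and comparable from Python. A short sketch of the new surface, assuming the module is built from this diff; the values in comments are illustrative:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

f32_vec = xla_client.Shape.array_shape(np.dtype(np.float32), (4,))
pair = xla_client.Shape.tuple_shape([f32_vec, f32_vec])

pair.is_tuple()        # True
pair.tuple_shapes()    # [f32_vec, f32_vec]
f32_vec.dimensions()   # (4,) -- now a Python tuple rather than a list
f32_vec.rank()         # 1

# Layouts are no longer filled in implicitly; set them when a consumer needs one.
laid_out = f32_vec.with_major_to_minor_layout_if_absent()

# __eq__/__hash__ make shapes usable as dictionary keys.
cache = {pair: 'tuple', laid_out: 'vector'}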
@ -171,10 +232,10 @@ PYBIND11_MODULE(xla_extension, m) {
*program_shape.mutable_result() = result;
return program_shape;
}))
.def("Parameters",
.def("parameter_shapes",
static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
&ProgramShape::parameters))
.def("Result", &ProgramShape::result)
.def("result_shape", &ProgramShape::result)
.def("__repr__", &ProgramShape::ToString);
// Literals
@ -211,22 +272,25 @@ PYBIND11_MODULE(xla_extension, m) {
// CPU custom-call targets.
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
// The LocalClient object allows dynamic attributes to allow external backends
// (e.g., TPU) to stash private data in the client.
py::class_<PyLocalClient>(m, "LocalClient", py::dynamic_attr())
.def_static("Get", &PyLocalClient::Get)
py::class_<PyLocalClient, std::shared_ptr<PyLocalClient>>(m, "LocalClient")
.def_static("Get", &PyLocalClient::Get, py::arg("platform"),
py::arg("xla_platform_id"), py::arg("asynchronous"))
.def("DeviceCount", &PyLocalClient::device_count)
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
.def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
.def_static("FromPython", &PyLocalBuffer::FromPython)
.def_static("FromPythonValues", &PyLocalBuffer::FromPythonValues)
.def_static("MakeTuple", &PyLocalBuffer::MakeTuple)
.def("Delete", &PyLocalBuffer::Delete)
.def("DestructureTuple", &PyLocalBuffer::DestructureTuple)
.def("ToPython", &PyLocalBuffer::ToPython)
.def("shape", &PyLocalBuffer::on_host_shape);
.def_static("from_python", &PyLocalBuffer::FromPython)
.def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
.def("delete", &PyLocalBuffer::Delete)
.def("destructure", &PyLocalBuffer::DestructureTuple)
.def("to_py", &PyLocalBuffer::ToPython)
.def("shape", &PyLocalBuffer::on_host_shape)
.def("device", &PyLocalBuffer::device_ordinal)
.def("is_deleted", [](const PyLocalBuffer& buffer) {
return buffer.device_buffer() == nullptr;
});
py::class_<PyLocalExecutable>(m, "LocalExecutable")
.def_static("Compile", &PyLocalExecutable::Compile,
@ -301,7 +365,12 @@ PYBIND11_MODULE(xla_extension, m) {
// XlaBuilder.
py::module ops = m.def_submodule("ops", "XLA operations");
ops.def("AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&)>(&CrossReplicaSum));
ops.def("AllToAll", &AllToAll);
ops.def("CollectivePermute", &CollectivePermute);
ops.def("CrossReplicaSum",
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
&CrossReplicaSum));

View File

@ -28,7 +28,6 @@ import os
import numpy as np
import six
from six.moves import xrange
# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
@ -71,18 +70,10 @@ class Backend(object):
for pyval, device in pyvals_and_devices
]
@abc.abstractmethod
def delete_buffer(self, c_buffer):
"""Deletes buffer `c_buffer`."""
@abc.abstractmethod
def make_tuple(self, c_buffers, device_ordinal):
"""Makes a tuple from a sequence of backend buffer objects."""
@abc.abstractmethod
def destructure_tuple(self, c_buffer):
"""Destructures a tuple buffer into a sequence of buffers."""
@abc.abstractmethod
def compile(self, computation, compile_options):
"""Compiles a computation. Returns an executable."""
@ -103,47 +94,38 @@ class Backend(object):
class LocalBackend(Backend):
"""XLA backend implemented using the in-process xla::LocalClient API."""
def __init__(self, platform=None, xla_platform_id=None, asynchronous=False):
def __init__(self, platform, client):
"""Creates a new LocalBackend.
Args:
platform: A string; the user-visible platform name, e.g. 'gpu'.
xla_platform_id: A string; XLA's name for the platform, e.g., 'CUDA'.
asynchronous: A boolean; should we enable asynchronous execution?
(Experimental.)
client: An _xla.PyLocalClient object.
"""
super(LocalBackend, self).__init__(platform)
self.client = _xla.LocalClient.Get(platform, xla_platform_id, asynchronous)
self.client = client
def device_count(self):
return self.client.DeviceCount()
def buffer_from_pyval(self, pyval, device=0):
return _xla.PyLocalBuffer.FromPython(pyval, self.client, device)
return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
def buffers_from_pyvals(self, pyvals_and_devices):
return _xla.PyLocalBuffer.FromPythonValues(pyvals_and_devices, self.client)
def delete_buffer(self, c_buffer):
c_buffer.Delete()
return _xla.PyLocalBuffer.from_python_values(pyvals_and_devices,
self.client)
def make_tuple(self, c_buffers, device_ordinal):
return _xla.PyLocalBuffer.MakeTuple(c_buffers, self.client, device_ordinal)
def destructure_tuple(self, c_buffer):
return c_buffer.DestructureTuple()
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal)
def compile(self, c_computation, compile_options):
options = _xla.ExecutableBuildOptions()
options.num_replicas = compile_options.num_replicas
if compile_options.argument_layouts:
argument_layouts = [
s.as_xla_shape() for s in compile_options.argument_layouts
]
argument_layouts = compile_options.argument_layouts
else:
argument_layouts = c_computation.GetProgramShape().Parameters()
argument_layouts = c_computation.GetProgramShape().parameter_shapes()
if compile_options.result_layout:
options.result_layout = compile_options.result_layout.as_xla_shape()
options.result_layout = compile_options.result_layout
options.debug_options.xla_cpu_fast_math_honor_infs = True
options.debug_options.xla_cpu_fast_math_honor_nans = True
return _xla.LocalExecutable.Compile(c_computation, argument_layouts,
@ -159,10 +141,22 @@ class LocalBackend(Backend):
return executable.ExecutePerReplica(per_replica_args)
def _cpu_backend_factory():
client = _xla.LocalClient.Get(
platform='cpu', xla_platform_id='Host', asynchronous=True)
return LocalBackend(platform='cpu', client=client)
def _gpu_backend_factory():
client = _xla.LocalClient.Get(
platform='gpu', xla_platform_id='CUDA', asynchronous=False)
return LocalBackend(platform='gpu', client=client)
# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
('cpu', lambda: LocalBackend(platform='cpu', xla_platform_id='Host')),
('gpu', lambda: LocalBackend(platform='gpu', xla_platform_id='CUDA')),
('cpu', _cpu_backend_factory),
('gpu', _gpu_backend_factory),
])
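These factories split client construction out of the backend wrapper: the C++ LocalClient is created once, with the asynchronous flag chosen per platform, and the resulting object is handed to LocalBackend. A hedged sketch of both paths, mirroring _cpu_backend_factory above; get_local_backend is the module's existing accessor for the cached default backend:

from tensorflow.compiler.xla.python import xla_client
from tensorflow.compiler.xla.python import xla_extension as _xla

# Usual path: let the factory table above pick and cache a backend.
backend = xla_client.get_local_backend()

# Equivalent by hand, following _cpu_backend_factory.
client = _xla.LocalClient.Get(platform='cpu', xla_platform_id='Host',
                              asynchronous=True)
backend = xla_client.LocalBackend(platform='cpu', client=client)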
@ -291,95 +285,12 @@ def dtype_to_etype(dtype):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
class Buffer(object):
"""Represents a handle to data owned by XLA.
The referent is ready for use in executing a local, compiled
Computation. On XLA platforms involving a device (e.g. GPU), this
means the referent is in device memory.
"""
def __init__(self, c_buffer, backend, device):
self.c_buffer = c_buffer
self._backend = backend
self._device = device
@staticmethod
def from_pyval(pyval, device=0, backend=None):
"""Copies the `pyval` to a freshly allocated on-device buffer."""
backend = backend or get_local_backend()
pyval = require_numpy_array_layout(pyval)
cbuf = backend.buffer_from_pyval(pyval, device)
return Buffer(cbuf, backend, device)
@staticmethod
def from_pyvals(pyvals_and_devices, backend=None):
"""Copies multiple Python values to freshly allocated on-device buffers.
Arguments:
pyvals_and_devices: a list of `(pyval, device)` pairs, where `pyval` is
a Python value to copy (e.g., a NumPy array), and `device` is an integer
device ordinal.
backend: a Backend object, or `None` to use the default local backend.
Returns:
A list of `Buffer` objects corresponding to `pyvals_and_devices`.
"""
backend = backend or get_local_backend()
pyvals_and_devices = [(require_numpy_array_layout(pyval), device)
for pyval, device in pyvals_and_devices]
cbufs = backend.buffers_from_pyvals(pyvals_and_devices)
return [
Buffer(cbuf, backend, device)
for cbuf, (_, device) in zip(cbufs, pyvals_and_devices)
]
@staticmethod
def make_tuple(buffers, backend=None, device=0):
backend = backend or get_local_backend()
buf = backend.make_tuple([b.c_buffer for b in buffers],
device_ordinal=device)
return Buffer(buf, backend, device)
def to_py(self):
return self.c_buffer.ToPython()
def shape(self):
return _wrap_shape(self.c_buffer.shape())
def device(self):
return self._device
def delete(self):
if self.c_buffer is not None:
self._backend.delete_buffer(self.c_buffer)
self.c_buffer = None
def destructure(self):
"""Assuming a tuple buffer, unpack it into constituent tuple elements."""
assert self.c_buffer is not None
result = self._backend.destructure_tuple(self.c_buffer)
return tuple(
Buffer(sub_buffer, device=self._device, backend=self._backend)
for sub_buffer in result)
def is_deleted(self):
return self.c_buffer is None
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer
class Format(enum.IntEnum):
"""Python copy of the Format protocol buffer enum."""
INVALID_FORMAT = 0
DENSE = 1
SPARSE = 2
Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:
class Shape(object):
"""Represents an XLA shape.
'''Represents an XLA shape.
A shape is either an array shape, having rank-many integer
dimensions and an element type (represented by a Numpy dtype), or it
@ -388,188 +299,120 @@ class Shape(object):
type shape =
TupleShape of shape list
| ArrayShape of { dimensions: int list; element_type: dtype }
'''
Callers are expected to instantiate this class only via the static
constructors: tuple_shape, array_shape, and from_pyval.
@staticmethod
def tuple_shape(tuple_shapes) -> Shape:
"Construct a tuple shape."
@staticmethod
def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
@staticmethod
def from_pyval(pyval) -> Shape:
"Returns a Shape that describes a tuple-tree of Numpy arrays."
def __eq__(self, other: Shape) -> bool:
def __ne__(self, other: Shape) -> bool:
def __hash__(self):
def __repr__(self):
def is_tuple(self) -> bool:
def is_array(self) -> bool:
def tuple_shapes(self) -> [Shape]:
def numpy_dtype(self) -> np.dtype:
"Like element_type(), but returns dtype('O') for a tuple shape."
def xla_element_type(self) -> PrimitiveType:
def element_type(self) -> np.dtype:
def dimensions(self) -> (int, int, ...):
def rank(self) -> int:
def minor_to_major(self) -> [int]:
def with_major_to_minor_layout_if_absent(self) -> Shape:
"Returns a copy with missing layouts set to major-to-minor."
def to_serialized_proto(self) -> bytes:
"Returns 'shape' as a serialized proto."
"""
ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.
class ProgramShape(object):
def __init__(self, parameter_shapes, result_shape):
def parameter_shapes(self) -> [Shape]:
def result_shape(self) -> Shape:
def __repr__(self):
"""
class Buffer(object):
"""Represents a handle to data owned by XLA.
The referent is ready for use in executing a local, compiled
Computation. On XLA platforms involving a device (e.g. GPU), this
means the referent is in device memory.
"""
@staticmethod
def tuple_shape(tuple_shapes):
"""Construct a tuple shape."""
if (not isinstance(tuple_shapes, (tuple, list)) or
not all(isinstance(t, Shape) for t in tuple_shapes)):
raise TypeError('tuple_shapes must be a tuple of Shapes')
return Shape(tuple_shapes, tuple)
def from_pyval(pyval, device=0, backend=None):
"""Copies the `pyval` to a freshly allocated on-device buffer."""
backend = backend or get_local_backend()
pyval = require_numpy_array_layout(pyval)
return backend.buffer_from_pyval(pyval, device)
@staticmethod
def array_shape(element_type, dimensions, minor_to_major=None):
"""Construct an array shape."""
if (not isinstance(dimensions, tuple) or
not all(isinstance(i, int) for i in dimensions)):
dimensions = tuple(int(i) for i in dimensions)
return Shape(
dimensions, np.dtype(element_type), minor_to_major=minor_to_major)
def from_pyvals(pyvals_and_devices, backend=None):
"""Copies multiple Python values to freshly allocated on-device buffers.
@staticmethod
def from_pyval(pyval):
"""Returns a Shape that describes a tuple-tree of Numpy arrays."""
def convert(pyval):
if isinstance(pyval, tuple):
return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
else:
pyval = require_numpy_array_layout(pyval)
return Shape.array_shape(pyval.dtype, np.shape(pyval))
return convert(pyval)
def __init__(self, dimensions, dtype, minor_to_major=None):
assert isinstance(dimensions, tuple)
self._dimensions = dimensions
self._dtype = dtype
self._is_tuple = dtype == tuple
self._minor_to_major = minor_to_major
self._check_minor_to_major()
def __eq__(self, other):
# pylint: disable=protected-access
return (self._dtype == other._dtype and
self._dimensions == other._dimensions and
self._minor_to_major == other._minor_to_major)
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash((self._dtype, self._dimensions, self._minor_to_major))
def __repr__(self):
return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
'_is_tuple={!r}, _minor_to_major={!r})').format(
self._dtype, self._dimensions, self._is_tuple,
self._minor_to_major)
def is_tuple(self):
return self._is_tuple
def is_array(self):
return not self._is_tuple
def tuple_shapes(self):
if not self.is_tuple():
raise ValueError('not a tuple shape')
return self._dimensions
def numpy_dtype(self):
"""Like element_type(), but returns dtype('O') in case of a tuple shape."""
if self.is_tuple():
return np.dtype(np.object)
else:
return self.element_type()
def xla_element_type(self):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())]
def element_type(self):
if not self.is_array():
raise ValueError('not an array shape')
return self._dtype
def dimensions(self):
if not self.is_array():
raise ValueError('not an array shape')
return self._dimensions
def rank(self):
return len(self.dimensions())
def minor_to_major(self):
return self._minor_to_major
def map_leaves(self, f):
"""Map f over each leaf-level array subshape.
Args:
f: The function to apply. Whenever f returns None, the identity is applied
instead.
Arguments:
pyvals_and_devices: a list of `(pyval, device)` pairs, where `pyval` is a
Python value to copy (e.g., a NumPy array), and `device` is an integer
device ordinal.
backend: a Backend object, or `None` to use the default local backend.
Returns:
A new Shape with the mapped leaves.
A list of `Buffer` objects corresponding to `pyvals_and_devices`.
"""
if self.is_tuple():
children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
return Shape.tuple_shape(children)
backend = backend or get_local_backend()
pyvals_and_devices = [(require_numpy_array_layout(pyval), device)
for pyval, device in pyvals_and_devices]
return backend.buffers_from_pyvals(pyvals_and_devices)
@staticmethod
def make_tuple(buffers, backend=None, device=0):
backend = backend or get_local_backend()
return backend.make_tuple(buffers, device_ordinal=device)
# Buffer is not an instantiable type and exists only for its static methods.
# The underlying buffer objects are C++ object with the following
# API:
# def to_py(self):
# def shape(self) -> Shape:
# def device(self) -> int:
# def delete(self):
# def destructure(self) -> [Buffer]
# def is_deleted(self) -> bool:
#
# TODO(phawkins): remove Buffer and its static methods completely, have
# clients call methods on Backend to create buffers.
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer
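With the wrapper class gone, Buffer's static constructors now return the C++ buffer objects directly, and the snake_case instance methods listed in the comment above are the ones to call. A minimal round-trip sketch, assuming a working local backend; the array contents are arbitrary:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

a = np.arange(4, dtype=np.float32)
buf = xla_client.Buffer.from_pyval(a)      # backed by _xla.PyLocalBuffer.from_python

buf.shape()                                # Shape of the on-host value
buf.device()                               # device ordinal holding the data
np.testing.assert_array_equal(buf.to_py(), a)

tup = xla_client.Buffer.make_tuple([buf])  # _xla.PyLocalBuffer.make_tuple
elements = tup.destructure()               # back to a sequence of buffers

buf.delete()
assert buf.is_deleted()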
def shape_from_pyval(pyval):
"""Returns a Shape that describes a tuple-tree of Numpy arrays."""
def convert(pyval):
if isinstance(pyval, tuple):
return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
else:
mapped = f(self)
return self if mapped is None else mapped
pyval = require_numpy_array_layout(pyval)
return Shape.array_shape(pyval.dtype, np.shape(pyval))
def _check_minor_to_major(self):
mtm = self._minor_to_major
if self.is_tuple():
assert mtm is None, self
if mtm is not None:
assert self.rank() == len(mtm), self
assert sorted(mtm) == list(range(len(mtm))), self
def update_minor_to_major(self, minor_to_major):
if not self.is_array():
raise ValueError('not an array shape')
if not isinstance(minor_to_major, tuple):
raise TypeError('minor_to_major must be a tuple')
updated = Shape.array_shape(self.element_type(), self.dimensions(),
minor_to_major)
updated._check_minor_to_major() # pylint: disable=protected-access
return updated
def with_major_to_minor_layout_if_absent(self):
"""Returns a copy of a shape with missing layouts set to major-to-minor."""
def f(a):
if a.minor_to_major():
return None
return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1)))
return self.map_leaves(f)
def serialize(self, proto):
"""Serializes 'shape' into proto."""
if self.is_tuple():
proto.element_type = int(PrimitiveType.TUPLE)
for shape in self.tuple_shapes():
shape.serialize(proto.tuple_shapes.add())
else:
proto.element_type = int(self.xla_element_type())
proto.dimensions.extend(self.dimensions())
proto.is_dynamic_dimension.extend([False for _ in self.dimensions()])
if self.minor_to_major():
proto.layout.format = Format.DENSE
proto.layout.minor_to_major.extend(self.minor_to_major())
def as_xla_shape(self):
if self.is_tuple():
return _xla.Shape.Tuple([x.as_xla_shape() for x in self.tuple_shapes()])
return _xla.Shape.Array(self.xla_element_type(), self.dimensions(),
self.minor_to_major())
ProgramShape = collections.namedtuple('ProgramShape',
('parameter_shapes', 'result_shape'))
def _wrap_shape(xla_shape):
element_type = xla_shape.element_type()
if element_type == PrimitiveType.TUPLE:
shapes = tuple(_wrap_shape(sub) for sub in xla_shape.tuple_shapes())
return Shape.tuple_shape(shapes)
else:
dtype = XLA_ELEMENT_TYPE_TO_DTYPE[element_type]
return Shape.array_shape(dtype, xla_shape.dimensions())
def _wrap_program_shape(program_shape):
return ProgramShape([_wrap_shape(arg) for arg in program_shape.Parameters()],
_wrap_shape(program_shape.Result()))
return convert(pyval)
def require_numpy_array_layout(value):
@ -612,8 +455,7 @@ def transfer_from_outfeed(shape, device_ordinal=0):
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
return backend.client.TransferFromOutfeed(
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
device_ordinal)
shape.with_major_to_minor_layout_if_absent(), device_ordinal)
class CompileOptions(object):
@ -699,10 +541,10 @@ class Computation(object):
return Executable(c, backend=backend)
def GetProgramShape(self):
return _wrap_program_shape(self._c_computation.GetProgramShape())
return self._c_computation.GetProgramShape()
def GetReturnValueShape(self):
return _wrap_shape(self._c_computation.GetProgramShape().Result())
return self._c_computation.GetProgramShape().result_shape()
class Executable(object):
@ -717,14 +559,12 @@ class Executable(object):
"""Returns a list containing the device ordinals for each replica."""
return self._device_ordinals
def Execute(self, arguments=(), check_for_deleted_args=True):
def Execute(self, arguments=None, check_for_deleted_args=True):
"""Execute on one replica with Buffer arguments and return value."""
arguments = arguments or []
if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
raise ValueError('Executing with deleted local buffer argument')
raw_args = [arg.c_buffer for arg in arguments]
output_buffer = self._backend.execute(self._c_executable, raw_args)
return Buffer(
output_buffer, backend=self._backend, device=self._device_ordinals[0])
return self._backend.execute(self._c_executable, arguments)
def ExecutePerReplica(self, arguments=None):
"""Execute on many replicas with Buffer arguments and return value.
@ -753,23 +593,8 @@ class Executable(object):
'Executing on device {} with argument from device {}'.format(
self._device_ordinals[replica], arg.device()))
# Pull out argument buffer handles
# pylint: disable=g-complex-comprehension
stripped_args = [
[arg.c_buffer for arg in replica_args] for replica_args in arguments
]
# Execute
output_buffers = self._backend.execute_replicated(self._c_executable,
stripped_args)
# Wrap output handles in Buffer instances
return tuple(
Buffer(
output_buffer,
backend=self._backend,
device=self._device_ordinals[replica])
for replica, output_buffer in enumerate(output_buffers))
return self._backend.execute_replicated(self._c_executable, arguments)
def ExecuteWithPythonValues(self, arguments=()):
"""Execute on one replica with Python values as arguments and output."""
@ -877,7 +702,7 @@ class ComputationBuilder(object):
return Computation(self._builder.Build(), backend=backend)
def GetShape(self, operand):
return _wrap_shape(self._builder.GetShape(operand))
return self._builder.GetShape(operand)
def SetOpMetadata(self, op_metadata):
"""Set metadata for operations that are about to be enqueued."""
@ -896,9 +721,8 @@ class ComputationBuilder(object):
Returns:
An XlaOp.
"""
return ops.Infeed(
self._builder,
shape.with_major_to_minor_layout_if_absent().as_xla_shape())
return ops.Infeed(self._builder,
shape.with_major_to_minor_layout_if_absent())
def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@ -995,10 +819,9 @@ class ComputationBuilder(object):
if parameter_num is None:
parameter_num = next(self._parameter_numbering)
return ops.Parameter(
self._builder, parameter_num,
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
name.encode('utf8'))
return ops.Parameter(self._builder, parameter_num,
shape.with_major_to_minor_layout_if_absent(),
name.encode('utf8'))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@ -1013,7 +836,7 @@ class ComputationBuilder(object):
An XlaOp.
"""
return self.ParameterWithShape(
Shape.from_pyval(value), name=name, parameter_num=parameter_num)
shape_from_pyval(value), name=name, parameter_num=parameter_num)
def Iota(self, dtype, size):
"""Enqueues an iota constant onto the computation.
@ -1040,7 +863,7 @@ class ComputationBuilder(object):
An XlaOp representing the added broadcasted iota constant.
"""
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
xla_shape = _xla.Shape.Array(element_type, shape, None)
xla_shape = _xla.Shape.array_shape(element_type, shape, None)
return ops.Iota(self._builder, xla_shape, dimension)
def Concatenate(self, operands, dimension):
@ -1097,6 +920,24 @@ class ComputationBuilder(object):
dimensions = tuple(range(ndim))
return ops.Reshape(operand, dimensions, new_sizes)
def AllReduce(self, operand, computation, replica_groups=None):
"""AllReduce op.
Args:
operand: XlaOp representing the input array.
computation: a Computation object; the binary reduction function to apply.
replica_groups: optional, list of lists of ints encoding a partition of
the set {0, 1, ..., num_replicas - 1} into equally-sized replica groups
within which the all-reduce is performed. If not supplied or None (the
default), all replicas belong to the same group.
Returns:
An XlaOp that represents the all-reduced result.
"""
replica_groups_protos = _get_replica_groups_protos(replica_groups)
return ops.AllReduce(operand, computation.computation,
replica_groups_protos, None)
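A hedged sketch of driving the new AllReduce method; it assumes four replicas, that a scalar F32 reduction computation can be built with the same ComputationBuilder API (the Add binary op is assumed to be among the forwarded ops), and it is illustrative rather than code from this change:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Binary reduction computation: (x, y) -> x + y on scalar f32.
rb = xla_client.ComputationBuilder('sum_f32')
rb.Add(rb.ParameterFromNumpy(np.float32(0)), rb.ParameterFromNumpy(np.float32(0)))
reducer = rb.Build()

# Main computation: all-reduce within replica groups {0, 1} and {2, 3}.
b = xla_client.ComputationBuilder('allreduce_example')
x = b.ParameterFromNumpy(np.float32(0))
b.AllReduce(x, reducer, replica_groups=[[0, 1], [2, 3]])
c = b.Build()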
def AllToAll(self,
operand,
split_dimension,
@ -1117,13 +958,7 @@ class ComputationBuilder(object):
Returns:
An XlaOp that represents the all-to-all concatenation.
"""
if replica_groups is None:
replica_groups_protos = [] # special value for XLA API
else:
replica_groups = list(replica_groups)
replica_groups_protos = [
_make_replica_group_proto(group) for group in replica_groups
]
replica_groups_protos = _get_replica_groups_protos(replica_groups)
if not replica_groups:
split_count = 1
else:
@ -1146,13 +981,8 @@ class ComputationBuilder(object):
Returns:
An XlaOp that represents on each replica the sum of its group's values.
"""
if replica_groups is None:
replica_groups = [] # special value for XLA API
else:
replica_groups = [
_make_replica_group_proto(group) for group in replica_groups
]
return ops.CrossReplicaSum(operand, replica_groups)
replica_groups_protos = _get_replica_groups_protos(replica_groups)
return ops.CrossReplicaSum(operand, replica_groups_protos)
def Trans(self, operand):
"""Specialized matrix transpose op."""
@ -1298,10 +1128,9 @@ class ComputationBuilder(object):
An XlaOp representing the added custom call op.
"""
opaque = opaque or b''
return ops.CustomCall(
self._builder, call_target_name, list(operands),
shape_with_layout.as_xla_shape(),
[s.as_xla_shape() for s in operand_shapes_with_layout], opaque)
return ops.CustomCall(self._builder, call_target_name,
list(operands), shape_with_layout,
list(operand_shapes_with_layout), opaque)
def Map(self, operands, computation_to_apply, dimensions):
"""Enqueues a map operation onto the computation.
@ -1389,7 +1218,7 @@ class ComputationBuilder(object):
dims: A 1D array-like of nonnegative integers specifying the dimensions.
Returns: an XlaOp to the generated array of F32 values.
"""
shape = _xla.Shape.Array(self.GetShape(mu).xla_element_type(), dims)
shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims)
return ops.RngNormal(mu, sigma, shape)
def RngUniform(self, a, b, dims):
@ -1406,7 +1235,7 @@ class ComputationBuilder(object):
Returns: an XlaOp to the generated array of values with the same numeric type
(F32, S32, or U32) as the arguments a and b.
"""
shape = _xla.Shape.Array(self.GetShape(a).xla_element_type(), dims)
shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims)
return ops.RngUniform(a, b, shape)
def While(self, cond, body, init):
@ -1659,7 +1488,6 @@ class ComputationBuilder(object):
FftType = _xla.FftType
_UNARY_OPS = [
'Not',
'Clz',
@ -1732,6 +1560,7 @@ _OTHER_OPS = [
'Cholesky',
'Clamp',
'Collapse',
'CollectivePermute',
'ConvertElementType',
'Dot',
'Gather',
@ -1885,3 +1714,14 @@ def _make_replica_group_proto(replica_group):
replica_group_proto = ReplicaGroup()
replica_group_proto.replica_ids.extend(replica_group)
return replica_group_proto
def _get_replica_groups_protos(replica_groups):
if replica_groups is None:
replica_groups_protos = [] # special value for XLA API
else:
replica_groups = list(replica_groups)
replica_groups_protos = [
_make_replica_group_proto(group) for group in replica_groups
]
return replica_groups_protos

View File

@ -315,10 +315,11 @@ class ComputationsWithConstantsTest(ComputationTest):
c.CustomCall(
b"test_subtract_f32",
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()),
shape_with_layout=xla_client.Shape.array_shape(
np.dtype(np.float32), (), ()),
operand_shapes_with_layout=(
xla_client.Shape.array_shape(np.float32, (), ()),
xla_client.Shape.array_shape(np.float32, (), ()),
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
))
self._ExecuteAndCompareClose(c, expected=0.75)
@ -1745,7 +1746,7 @@ class EmbeddedComputationsTest(ComputationTest):
def testInfeedS32Values(self):
to_infeed = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
c.Infeed(xla_client.shape_from_pyval(to_infeed[0]))
compiled_c = c.Build().Compile()
for item in to_infeed:
xla_client.transfer_to_infeed(item)
@ -1757,7 +1758,7 @@ class EmbeddedComputationsTest(ComputationTest):
def testInfeedThenOutfeedS32(self):
to_round_trip = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
x = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0]))
c.Outfeed(x)
compiled_c = c.Build().Compile()
@ -1767,7 +1768,7 @@ class EmbeddedComputationsTest(ComputationTest):
execution.start()
xla_client.transfer_to_infeed(want)
got = xla_client.transfer_from_outfeed(
xla_client.Shape.from_pyval(to_round_trip[0]))
xla_client.shape_from_pyval(to_round_trip[0]))
execution.join()
self.assertEqual(want, got)
@ -1803,7 +1804,9 @@ class ErrorTest(ComputationTest):
c.ClearOpMetadata()
options = xla_client.CompileOptions()
options.argument_layouts = [xla_client.Shape.array_shape(np.float32, [])]
options.argument_layouts = [
xla_client.Shape.array_shape(np.dtype(np.float32), [])
]
def TestFun():
return c.Build().Compile(compile_options=options)

View File

@ -84,9 +84,9 @@ void AddXrtSubmodule(py::module* module) {
.def_property_readonly("tf_device_ids", &XrtContext::tf_device_ids);
py::class_<XrtBuffer, std::shared_ptr<XrtBuffer>>(m, "XrtBuffer")
.def_static("FromLiteral", &XrtBuffer::FromLiteral)
.def_static("MakeTuple", &XrtBuffer::MakeTuple)
.def("ToPython",
.def_static("from_literal", &XrtBuffer::FromLiteral)
.def_static("make_tuple", &XrtBuffer::MakeTuple)
.def("to_py",
[](std::shared_ptr<XrtBuffer> buffer) -> xla::StatusOr<py::object> {
auto literal = absl::make_unique<xla::Literal>();
{
@ -95,8 +95,10 @@ void AddXrtSubmodule(py::module* module) {
}
return xla::LiteralToPython(std::move(literal));
})
.def("Delete", &XrtBuffer::Delete)
.def("DestructureTuple", &XrtBuffer::DestructureTuple);
.def("delete", &XrtBuffer::Delete)
.def("destructure", &XrtBuffer::DestructureTuple)
.def("is_deleted",
[](const XrtBuffer& buffer) { return !buffer.handle().valid(); });
py::class_<XrtExecutable, std::shared_ptr<XrtExecutable>>(m, "XrtExecutable")
.def_static("Compile",

View File

@ -31,13 +31,6 @@ from tensorflow.compiler.xla.python import xla_extension as _xla
# pylint: enable=g-direct-tensorflow-import
def _make_xla_shape(shape):
if shape.is_tuple():
return _xla.Shape.Tuple([_make_xla_shape(s) for s in shape.tuple_shapes()])
return _xla.Shape.Array(shape.xla_element_type(), shape.dimensions(),
shape.minor_to_major())
def get_tf_context(target, worker):
"""Returns a TensorFlow RPC client object.
@ -60,7 +53,8 @@ class XrtBackend(xla_client.Backend):
tf_device_type: the type of TensorFlow device to use for XRT (e.g. `"TPU"`).
"""
def __init__(self, tf_context, tf_device_type):
def __init__(self, tf_context, tf_device_type, platform="tpu"):
super(XrtBackend, self).__init__(platform)
self.tf_device_type = tf_device_type
self.context = _xla.xrt.XrtContext.Create(tf_context, tf_device_type)
@ -69,30 +63,23 @@ class XrtBackend(xla_client.Backend):
return self.context.DeviceCount()
def buffer_from_pyval(self, pyval, device=0):
return _xla.xrt.XrtBuffer.FromLiteral(self.context, device, pyval)
def delete_buffer(self, c_buffer):
c_buffer.Delete()
def destructure_tuple(self, c_buffer):
return c_buffer.DestructureTuple()
return _xla.xrt.XrtBuffer.from_literal(self.context, device, pyval)
def make_tuple(self, buffers, device_ordinal):
return _xla.xrt.XrtBuffer.MakeTuple(self.context, buffers)
return _xla.xrt.XrtBuffer.make_tuple(self.context, buffers)
def compile(self, computation, compile_options):
# pylint: disable=protected-access
program_shape = xla_client._wrap_program_shape(
computation.GetProgramShape())
program_shape = computation.GetProgramShape()
# pylint: enable=protected-access
proto = computation.GetSerializedProto()
# TODO(phawkins): use the layouts in compile_options.
arg_shapes = [
_make_xla_shape(shape.with_major_to_minor_layout_if_absent())
for shape in program_shape.parameter_shapes
shape.with_major_to_minor_layout_if_absent()
for shape in program_shape.parameter_shapes()
]
result_shape = _make_xla_shape(
program_shape.result_shape.with_major_to_minor_layout_if_absent())
result_shape = (
program_shape.result_shape().with_major_to_minor_layout_if_absent())
device_assignment = _xla.xrt.AssignDevices(compile_options.num_replicas, 1)
return _xla.xrt.XrtExecutable.Compile(self.context, proto, arg_shapes,
result_shape, device_assignment)

View File

@ -48,7 +48,7 @@ class XrtBackendTest(test.TestCase):
b = np.arange(10)
c = BuildAddAndScaleComputation(
xla_client.Shape.from_pyval(a), xla_client.Shape.from_pyval(b))
xla_client.shape_from_pyval(a), xla_client.shape_from_pyval(b))
executable = c.Compile(backend=backend)
output = executable.ExecuteWithPythonValues((a, b))

View File

@ -18,6 +18,9 @@ package_group(
includes = [
"//tensorflow/compiler/xla:friends",
],
packages = [
"//learning/brain/experimental/tf_runtime/...",
],
)
xla_proto_library(
@ -434,10 +437,10 @@ tf_cc_test(
srcs = ["pattern_matcher_test.cc"],
deps = [
":hlo",
":hlo_parser",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
@ -505,8 +508,8 @@ cc_library(
hdrs = ["hlo_matchers.h"],
deps = [
":hlo",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@ -549,13 +552,13 @@ tf_cc_test(
srcs = ["hlo_sharding_test.cc"],
deps = [
":hlo",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@ -583,6 +586,7 @@ tf_cc_test(
srcs = ["call_graph_test.cc"],
deps = [
":call_graph",
":hlo",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -590,7 +594,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@ -653,6 +656,7 @@ tf_cc_test(
deps = [
":call_graph",
":flatten_call_graph",
":hlo",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -660,7 +664,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@ -691,7 +694,6 @@ cc_library(
deps = [
":compiler",
":computation_placer",
":device_memory_allocator",
":platform_util",
":stream_pool",
":transfer_manager",
@ -701,6 +703,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
@ -721,7 +724,6 @@ cc_library(
":compiler",
":computation_layout",
":computation_placer",
":device_memory_allocator",
":dump",
":dynamic_dimension_inference",
":executable",
@ -751,6 +753,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@ -767,7 +770,6 @@ cc_library(
":backend",
":compiler",
":computation_layout",
":device_memory_allocator",
":executable",
":hlo",
":hlo_execution_profile",
@ -787,6 +789,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@ -855,7 +858,6 @@ cc_library(
srcs = ["shaped_buffer.cc"],
hdrs = ["shaped_buffer.h"],
deps = [
":device_memory_allocator",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -865,6 +867,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -878,7 +881,6 @@ tf_cc_test(
srcs = ["shaped_buffer_test.cc"],
deps = [
":cpu_plugin",
":device_memory_allocator",
":platform_util",
":shaped_buffer",
"//tensorflow/compiler/xla:shape_util",
@ -888,6 +890,7 @@ tf_cc_test(
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
],
)
@ -901,7 +904,6 @@ cc_library(
],
deps = [
":computation_layout",
":device_memory_allocator",
":dump",
":hlo",
":hlo_execution_profile",
@ -922,6 +924,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@ -988,7 +991,6 @@ cc_library(
hdrs = ["allocation_tracker.h"],
deps = [
":backend",
":device_memory_allocator",
":transfer_manager",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -997,6 +999,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -1156,6 +1159,7 @@ tf_cc_test(
":hlo",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -1163,7 +1167,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -1205,10 +1208,10 @@ tf_cc_test(
":hlo_dataflow_analysis",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@ -1455,8 +1458,8 @@ tf_cc_test(
srcs = ["instruction_fusion_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":instruction_fusion",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@ -1467,11 +1470,11 @@ cc_library(
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
":hlo",
":hlo_pass",
":hlo_reachability",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -1663,6 +1666,7 @@ cc_library(
":hlo_pass",
":hlo_query",
":pattern_matcher",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@ -1788,8 +1792,8 @@ tf_cc_test(
srcs = ["gather_expander_test.cc"],
deps = [
":gather_expander",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@ -1887,9 +1891,9 @@ tf_cc_test(
name = "while_loop_analysis_test",
srcs = ["while_loop_analysis_test.cc"],
deps = [
":hlo_parser",
":while_loop_analysis",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@ -2294,7 +2298,7 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_parser",
":hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -2307,14 +2311,14 @@ tf_cc_test(
srcs = ["hlo_computation_test.cc"],
deps = [
":hlo",
":hlo_matchers",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_map",
@ -2519,13 +2523,13 @@ tf_cc_test(
deps = [
":hlo",
":hlo_liveness_analysis",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -2909,12 +2913,12 @@ tf_cc_test(
deps = [
":hlo",
":hlo_module_dce",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@ -3040,12 +3044,12 @@ tf_cc_test(
":hlo",
":hlo_cse",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@ -3229,27 +3233,6 @@ tf_cc_test(
],
)
cc_library(
name = "device_memory_allocator",
srcs = [
"device_memory_allocator.cc",
"owning_device_memory.cc",
],
hdrs = [
"device_memory_allocator.h",
"owning_device_memory.h",
],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "maybe_owning_device_memory",
srcs = [
@ -3259,7 +3242,7 @@ cc_library(
"maybe_owning_device_memory.h",
],
deps = [
":device_memory_allocator",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
@ -3302,10 +3285,10 @@ xla_test(
"gpu",
],
deps = [
":hlo_parser",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -3428,6 +3411,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
":hlo_parser",
":shape_inference",
":transpose_folding",
"//tensorflow/compiler/xla:literal",
@ -3436,7 +3420,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -3679,10 +3662,10 @@ tf_cc_test(
name = "tuple_util_test",
srcs = ["tuple_util_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":tuple_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -3708,11 +3691,11 @@ tf_cc_test(
name = "while_util_test",
srcs = ["while_util_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":while_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/algorithm:container",
],
@ -3743,9 +3726,9 @@ tf_cc_test(
srcs = ["while_loop_invariant_code_motion_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
],
@ -3771,9 +3754,9 @@ tf_cc_test(
srcs = ["while_loop_constant_sinking_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
],
@ -3973,6 +3956,8 @@ cc_library(
hdrs = ["ar_crs_combiner.h"],
deps = [
":call_graph",
":hlo",
":hlo_pass",
":pattern_matcher",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
@ -3980,8 +3965,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
@ -4005,11 +3988,11 @@ cc_library(
srcs = ["dynamic_index_splitter.cc"],
hdrs = ["dynamic_index_splitter.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_pass",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@ -4124,6 +4107,13 @@ cc_library(
],
)
cc_library(
name = "custom_call_target_registry",
srcs = ["custom_call_target_registry.cc"],
hdrs = ["custom_call_target_registry.h"],
visibility = ["//visibility:public"],
)
tf_cc_test(
name = "slice_sinker_test",
srcs = ["slice_sinker_test.cc"],

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -183,6 +184,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleCompare(HloInstruction* compare) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConstant(HloInstruction* constant) override;
@ -2213,6 +2216,49 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
HloInstruction* lhs;
HloInstruction* rhs;
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
auto replace_with_pred_broadcast = [&](bool value) {
return ReplaceWithNewInstruction(
compare,
HloInstruction::CreateBroadcast(
compare->shape(),
computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))),
{}));
};
if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return replace_with_pred_broadcast(false);
} else if (compare->comparison_direction() == ComparisonDirection::kGt &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
return replace_with_pred_broadcast(false);
} else if (compare->comparison_direction() == ComparisonDirection::kGe &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return replace_with_pred_broadcast(true);
} else if (compare->comparison_direction() == ComparisonDirection::kLe &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
return replace_with_pred_broadcast(true);
}
if (lhs == rhs &&
primitive_util::IsIntegralType(lhs->shape().element_type())) {
switch (compare->comparison_direction()) {
case ComparisonDirection::kGt:
case ComparisonDirection::kLt:
case ComparisonDirection::kNe:
return replace_with_pred_broadcast(false);
case ComparisonDirection::kEq:
case ComparisonDirection::kGe:
case ComparisonDirection::kLe:
return replace_with_pred_broadcast(true);
}
}
return Status::OK();
}
// A conversion to the same element type as the operand is a nop and can be
// removed. A conversion of a constant can be simplified by making a new
// constant.

View File

@ -5372,21 +5372,54 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
// This test exposes a real bug: It tries to read an out-of-bounds array index
// from within ComposePermutations(). TODO(b/132330723): Fix this.
TEST_F(AlgebraicSimplifierTest,
DotContractingReorder_NoChangeInContractingDimsOrder) {
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
// No optimization opportunity here because the transpose does not reorder the
// contracting dims.
const char* kModuleStr = R"(
param = f32[2,5,1,3] parameter(0)
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
reshape = f32[5,6] reshape(transpose)
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
ROOT dot = f32[5,4] dot(reshape, constant),
lhs_contracting_dims={1}, rhs_contracting_dims={0}}
HloModule m
test {
param = f32[2,5,1,3] parameter(0)
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
reshape = f32[5,6] reshape(transpose)
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
ROOT dot = f32[5,4] dot(reshape, constant),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}
TEST_F(AlgebraicSimplifierTest, CompareIota) {
const char* kModuleStr = R"(
HloModule m
test {
zero = s32[] constant(0)
iota = s32[128] iota(), iota_dimension=0
broad = s32[128] broadcast(zero), dimensions={}
ROOT compare = pred[128] compare(iota, broad), direction=LT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
TEST_F(AlgebraicSimplifierTest, CompareSame) {
const char* kModuleStr = R"(
HloModule m
test {
param = s32[123] parameter(0)
ROOT compare = pred[123] compare(param, param), direction=GE
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
} // namespace
} // namespace xla

View File

@ -20,13 +20,13 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
@ -221,8 +221,8 @@ void AllocationTracker::AddAllocationOrIncrementRefCount(
auto it = allocation_map.find(device_memory.opaque());
if (it == allocation_map.end()) {
allocation_map[device_memory.opaque()] = {
OwningDeviceMemory(device_memory, device_ordinal,
backend_->memory_allocator()),
se::OwningDeviceMemory(device_memory, device_ordinal,
backend_->memory_allocator()),
/*ref_count=*/1};
} else {
it->second.ref_count++;

View File

@ -77,7 +77,7 @@ class AllocationTracker {
// Data structure encapsulating single memory allocation on the device.
struct Allocation {
// The pointer to this allocation.
OwningDeviceMemory device_memory;
se::OwningDeviceMemory device_memory;
// This is the number of times this memory allocation is referred to by
// registered data handles.

View File

@ -107,44 +107,90 @@ absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
return absl::nullopt;
}
absl::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
HloInstruction* instruction) {
CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
HloComputation* computation = instruction->parent();
auto caller_instructions = call_graph_->GetComputationCallers(computation);
if (caller_instructions.size() == 1) {
auto caller_instruction = caller_instructions[0];
if (caller_instruction->opcode() == HloOpcode::kConditional) {
return caller_instruction;
}
}
return absl::nullopt;
}
std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples(
HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kTuple) {
return {instruction};
}
if (instruction->opcode() == HloOpcode::kDomain) {
return GetAllTuples(instruction->operands()[0]);
}
if (instruction->opcode() == HloOpcode::kParameter) {
auto maybe_while = WhileFromBodyParameter(instruction);
if (!maybe_while) {
return {};
}
auto while_instr = *maybe_while;
auto init_tuples = GetAllTuples(while_instr->while_init());
auto body_tuples =
GetAllTuples(while_instr->while_body()->root_instruction());
if (init_tuples.empty() || body_tuples.empty()) {
return {};
}
init_tuples.insert(init_tuples.end(), body_tuples.begin(),
body_tuples.end());
return init_tuples;
}
if (instruction->opcode() == HloOpcode::kGetTupleElement) {
std::vector<HloInstruction*> result_tuples;
for (auto tuple : GetAllTuples(instruction->operands()[0])) {
auto tmp_tuples =
GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
if (tmp_tuples.empty()) {
return {};
switch (instruction->opcode()) {
case HloOpcode::kTuple:
return {instruction};
case HloOpcode::kDomain:
return GetAllTuples(instruction->operands()[0]);
case HloOpcode::kParameter: {
auto maybe_while = WhileFromBodyParameter(instruction);
if (maybe_while) {
auto while_instr = *maybe_while;
auto init_tuples = GetAllTuples(while_instr->while_init());
auto body_tuples =
GetAllTuples(while_instr->while_body()->root_instruction());
if (init_tuples.empty() || body_tuples.empty()) {
return {};
}
init_tuples.insert(init_tuples.end(), body_tuples.begin(),
body_tuples.end());
return init_tuples;
}
result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
tmp_tuples.end());
auto maybe_conditional = ConditionalFromBodyParameter(instruction);
if (maybe_conditional) {
auto cond_instr = *maybe_conditional;
std::vector<HloInstruction*> tuples;
for (int64 i = 0; i < cond_instr->branch_computations().size(); ++i) {
if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
instruction) {
// If the same computation is used for more than one branch of the
// conditional, we collect the arguments that flow to the
// computation from all branches.
auto branch_tuples =
GetAllTuples(cond_instr->mutable_operand(i + 1));
if (branch_tuples.empty()) {
return {};
}
tuples.insert(tuples.end(), branch_tuples.begin(),
branch_tuples.end());
}
}
return tuples;
}
return {};
}
return result_tuples;
case HloOpcode::kGetTupleElement: {
std::vector<HloInstruction*> result_tuples;
for (auto tuple : GetAllTuples(instruction->operands()[0])) {
auto tmp_tuples =
GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
if (tmp_tuples.empty()) {
return {};
}
result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
tmp_tuples.end());
}
return result_tuples;
}
case HloOpcode::kConditional: {
std::vector<HloInstruction*> result_tuples;
for (HloComputation* body : instruction->branch_computations()) {
if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
return {};
}
result_tuples.push_back(body->root_instruction());
}
return result_tuples;
}
default:
return {};
}
return {};
}
bool ArCrsCombiner::TupleElementsComputeSameValue(

View File

@ -119,6 +119,12 @@ class ArCrsCombiner : public HloModulePass {
absl::optional<HloInstruction*> WhileFromBodyParameter(
HloInstruction* instruction);
// If the passed instruction is a parameter in one of the branch computations,
// and that branch computation is called by exactly one conditional
// instruction, return the conditional instruction.
absl::optional<HloInstruction*> ConditionalFromBodyParameter(
HloInstruction* instruction);
// Returns a vector of tuple instructions.
// If all instructions that flow to "instruction" are tuples, return them.
// Otherwise, return an empty vector.

View File

@ -1173,5 +1173,47 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
EXPECT_FALSE(changed);
}
TEST_F(ArCrsCombinerTest, SameValueTestConditional) {
const char* module_str = R"(
HloModule foobar
branch_true {
pt = (f32[2,4], f32[2,4]) parameter(0)
gte.0 = f32[2,4] get-tuple-element(pt), index=0
gte.1 = f32[2,4] get-tuple-element(pt), index=1
ROOT tuple.t = (f32[2,4], f32[2,4]) tuple(gte.1, gte.0)
}
branch_false {
pf = (f32[2,4], f32[2,4]) parameter(0)
gte.0 = f32[2,4] get-tuple-element(pf), index=0
gte.1 = f32[2,4] get-tuple-element(pf), index=1
add = f32[2,4] add(gte.1, gte.1)
ROOT tuple.f = (f32[2,4], f32[2,4]) tuple(gte.0, add)
}
ENTRY Parameters1.v4 {
constant = pred[] constant(true)
p = f32[2,4] parameter(0)
tuple = (f32[2,4], f32[2,4]) tuple(p, p)
ROOT conditional = (f32[2,4], f32[2,4]) conditional(constant, tuple, tuple), true_computation=branch_true, false_computation=branch_false
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto cond = module->entry_computation()->root_instruction();
auto branch_true = cond->branch_computation(0)->root_instruction();
auto t0 = branch_true->mutable_operand(0);
auto t1 = branch_true->mutable_operand(1);
EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(t0, t1));
auto branch_false = cond->branch_computation(1)->root_instruction();
auto f0 = branch_false->mutable_operand(0);
auto f1 = branch_false->mutable_operand(1);
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(f0, f1));
}
} // namespace
} // namespace xla

View File

@ -134,7 +134,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler,
}
}
// Create a memory allocator for the valid stream executors.
memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>(
memory_allocator_ = absl::make_unique<se::StreamExecutorMemoryAllocator>(
platform, stream_executors);
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -35,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace Eigen {
struct ThreadPoolDevice;
@ -88,7 +88,7 @@ class Backend {
// Accessors for the various objects.
se::Platform* platform() const { return platform_; }
Compiler* compiler() const { return compiler_; }
DeviceMemoryAllocator* memory_allocator() const {
se::DeviceMemoryAllocator* memory_allocator() const {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
@ -179,7 +179,7 @@ class Backend {
stream_pools_ GUARDED_BY(mu_);
// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
std::unique_ptr<se::StreamExecutorMemoryAllocator> memory_allocator_;
// For the CPU backend, an Eigen threadpool device for use by Eigen code.
struct IntraOpThreadPool;

View File

@ -75,8 +75,10 @@ class AotCompilationOptions {
// Optional allocator that may be used for allocating temp space on the device
// during compilation.
DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
se::DeviceMemoryAllocator* device_allocator() const {
return device_allocator_;
}
void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
device_allocator_ = device_allocator;
}
@ -98,7 +100,7 @@ class AotCompilationOptions {
AotCompilationOptions();
private:
DeviceMemoryAllocator* device_allocator_ = nullptr;
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
DebugOptions debug_options_;
absl::optional<DeviceAssignment> static_device_assignment_;
};
@ -147,14 +149,14 @@ class Compiler {
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;
se::DeviceMemoryAllocator* device_allocator) = 0;
// Optimizes an HLO module group, a set of modules which run concurrently on
// multiple devices, potentially communicating data between the modules.
virtual Status RunHloPassesOnModuleGroup(
HloModuleGroup* module_group,
absl::Span<se::StreamExecutor* const> executors,
DeviceMemoryAllocator* device_allocator) = 0;
se::DeviceMemoryAllocator* device_allocator) = 0;
// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. No HLO passes are
@ -168,7 +170,7 @@ class Compiler {
// device_allocator is optional; see RunHloPasses.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;
se::DeviceMemoryAllocator* device_allocator) = 0;
// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules.
@ -176,7 +178,7 @@ class Compiler {
RunBackendOnModuleGroup(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
DeviceMemoryAllocator* device_allocator) = 0;
se::DeviceMemoryAllocator* device_allocator) = 0;
// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
@ -189,7 +191,7 @@ class Compiler {
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
DeviceMemoryAllocator* device_allocator) = 0;
se::DeviceMemoryAllocator* device_allocator) = 0;
// Returns the backend configurations that the backend will consider for the
// given HLO. Returns no configurations if the backend does not support

View File

@ -182,7 +182,6 @@ cc_library(
deps = [
":compiler_functor",
":cpu_runtime",
":custom_call_target_registry",
":disassembler",
":orc_jit_memory_mapper",
":runtime_fp16",
@ -203,6 +202,7 @@ cc_library(
"@llvm//:orc_jit",
"@llvm//:support",
"@llvm//:target", # fixdeps: keep
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
@ -245,7 +245,6 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
@ -255,6 +254,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/host:host_stream",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@ -946,17 +946,6 @@ cc_library(
],
)
cc_library(
name = "custom_call_target_registry",
srcs = [
"custom_call_target_registry.cc",
],
hdrs = [
"custom_call_target_registry.h",
],
visibility = ["//visibility:public"],
)
cc_library(
name = "orc_jit_memory_mapper",
srcs = ["orc_jit_memory_mapper.cc"],

View File

@ -537,7 +537,7 @@ Status CreateHloProfilingArtifacts(
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
DeviceMemoryAllocator* /*device_allocator*/) {
se::DeviceMemoryAllocator* /*device_allocator*/) {
std::unique_ptr<llvm::TargetMachine> jit_target_machine =
SimpleOrcJIT::InferTargetMachineForJIT(
CompilerTargetOptions(module->config()),
@ -597,7 +597,7 @@ struct OrcJITPostCompilationHook {
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* /*device_allocator*/) {
se::DeviceMemoryAllocator* /*device_allocator*/) {
VLOG(1) << "Compiling: " << module->name();
XLA_SCOPED_LOGGING_TIMER(
absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

View File

@ -133,11 +133,11 @@ class CpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) override;
se::DeviceMemoryAllocator* device_allocator) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) override;
se::DeviceMemoryAllocator* device_allocator) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

View File

@ -73,13 +73,13 @@ CpuExecutable::CpuExecutable(
}
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<OwningDeviceMemory> owning_buffers(
std::vector<se::OwningDeviceMemory> owning_buffers(
assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction(
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
absl::Span<OwningDeviceMemory> buffers) {
absl::Span<se::OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(),
@ -216,7 +216,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
// Move OwningDeviceMemory values which contain the array(s) of the result
// Move se::OwningDeviceMemory values which contain the array(s) of the result
// into the respective location in ScopedShapedBuffer which is returned to the
// caller.
TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
@ -235,7 +235,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const BufferAllocation::Slice slice,
this->assignment_->GetUniqueSlice(src, buffer_source->index()));
const BufferAllocation::Index buffer_index = slice.index();
OwningDeviceMemory& buffer = buffers[buffer_index];
se::OwningDeviceMemory& buffer = buffers[buffer_index];
if (!slice.allocation()->is_entry_computation_parameter()) {
// If the buffer coming out of the result is from a parameter, the
// owning buffer will be null, and that means the caller aliased some
@ -297,8 +297,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
auto* host_stream = dynamic_cast<se::host::HostStream*>(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<OwningDeviceMemory> owning_buffers;
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<se::OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
@ -326,7 +326,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
std::shared_ptr<std::vector<se::OwningDeviceMemory>> buffers;
HloExecutionProfile* hlo_execution_profile;
void operator()() {
@ -338,7 +338,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
};
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<OwningDeviceMemory>>(
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(owning_buffers)),
hlo_execution_profile});

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
namespace cpu {
@ -111,8 +111,9 @@ class CpuExecutable : public Executable {
// storage and the live-out buffer into which the computation writes its
// result.
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<se::OwningDeviceMemory>>>
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given
@ -126,7 +127,7 @@ class CpuExecutable : public Executable {
// The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
absl::Span<OwningDeviceMemory> buffers);
absl::Span<se::OwningDeviceMemory> buffers);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.

View File

@ -1,74 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
// This file is depended on by kernels that have to build for mobile devices.
// For this reason, we avoid relying on TensorFlow and instead only use the
// standard C++ library.
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
namespace xla {
namespace cpu {
// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
// targets; so when using the CPU JIT, CustomCall targets need to be registered
// here with the symbol name used in the CustomCall.
//
// The XLA AOT compiler links using a standard offline linker; so when compiling
// in AOT mode, you *also* need to make sure the name of the callee (presumably
// implemented in C++) matches up with the symbolic name used in the CustomCall.
//
// We maintain the registry in both the JIT and the AOT cases for simplicity,
// but we only use it when running in JIT mode.
class CustomCallTargetRegistry {
public:
static CustomCallTargetRegistry* Global();
void Register(const std::string& symbol, void* address);
void* Lookup(const std::string& symbol) const;
private:
std::unordered_map<std::string, void*> registered_symbols_;
mutable std::mutex mu_;
};
class RegisterCustomCallTarget {
public:
explicit RegisterCustomCallTarget(const std::string& name, void* address) {
CustomCallTargetRegistry::Global()->Register(name, address);
}
};
#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \
static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \
custom_call_target_register, counter)(symbol, \
reinterpret_cast<void*>(address))
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__)
#define REGISTER_CUSTOM_CALL_TARGET(function) \
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
} // namespace cpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_

View File

@ -119,13 +119,9 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
int32 vector_width) {
VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");
// This implements the same polynomial approximation as implemented in Eigen3.
// This implements the same polynomial approximation as implemented in Cephes.
const llvm::APFloat half = GetIeeeF32(0.5);
const llvm::APFloat one = GetIeeeF32(1.0);
const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
const llvm::APFloat one = GetIeeeF32(1);
const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
@ -138,39 +134,79 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
llvm::Value* input_clamped =
vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
llvm::Value* x = vsl.Sub(input_clamped, tmp);
x = vsl.Sub(x, z);
z = vsl.Mul(x, x);
// To compute e^input, we re-express it as
//
// e^input = e^(a + b)
// = e^(a + n log(2))
// = e^a * 2^n.
//
// We choose n = floor(input * log2(e) + 0.5), i.e. n = round(input / log(2)),
// which keeps a = input - n * log(2) roughly within (-0.5, 0.5). We then use
// a polynomial to compute e^a.
llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
y = vsl.MulAdd(y, x, cephes_exp_p2);
y = vsl.MulAdd(y, x, cephes_exp_p3);
y = vsl.MulAdd(y, x, cephes_exp_p4);
y = vsl.MulAdd(y, x, cephes_exp_p5);
y = vsl.MulAdd(y, z, x);
y = vsl.Add(one, y);
// Restrict input to a small range, including some values whose exp
// overflows to +inf or underflows to zero. Our computations below aren't
// particularly sensitive to the exact choices here, so we choose values a
// bit larger/smaller than
//
// log(F32_MAX) = 88.723...
// log(smallest positive f32 denormal) = -103.279...
//
input = vsl.Clamp(input, GetIeeeF32(-104), GetIeeeF32(88.8));
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
llvm::Value* vector_constant_23 =
b->CreateVectorSplat(vector_width, b->getInt32(23));
llvm::Type* i32_vector_type =
llvm::VectorType::get(b->getInt32Ty(), vector_width);
// fx is clamped so we don't have to worry about it being out of range for
// i32.
llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type);
emm0 = b->CreateAdd(emm0, vector_constant_0x7f);
emm0 = b->CreateShl(emm0, vector_constant_23);
llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type());
llvm::Value* x = input;
llvm::Value* n = vsl.Floor(vsl.MulAdd(input, cephes_LOG2EF, half));
return vsl.Max(vsl.Mul(y, emm0_f32), input);
// When we eventually do the multiplication in e^a * 2^n, we need to handle
// the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
// (so e^a * 2^n != inf). There's a similar problem for n < -126, the
// smallest fp32 exponent.
//
// A straightforward solution would be to detect n out of range and split it
// up, doing
//
// e^a * 2^n = e^a * 2^(n1 + n2)
// = (2^n1 * e^a) * 2^n2.
//
// But it turns out this approach is quite slow. It's not clear why; our
// hypothesis is that the integer operations on the exponent `n` have nonlocal
// effects on the pipeline.
//
// The approach we use instead is to clamp n to [-126, 127] so 2^n doesn't
// over/underflow. This causes `a` to be outside the range (-0.5, 0.5), which
// means that our polynomial for e^a will give a less-accurate result. In
// practice this seems to work well enough; it passes our exhaustive tests,
// breaking only one result, and by one ulp (we return exp(88.7228394) =
// max-float but we should return inf).
n = vsl.Clamp(n, GetIeeeF32(-126), GetIeeeF32(127));
// Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
x = vsl.Sub(x, vsl.Mul(cephes_exp_C1, n));
x = vsl.Sub(x, vsl.Mul(cephes_exp_C2, n));
llvm::Value* z = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
z = vsl.MulAdd(z, x, cephes_exp_p2);
z = vsl.MulAdd(z, x, cephes_exp_p3);
z = vsl.MulAdd(z, x, cephes_exp_p4);
z = vsl.MulAdd(z, x, cephes_exp_p5);
z = vsl.MulAdd(z, vsl.Mul(x, x), x);
z = vsl.Add(one, z);
// Convert n to an i32. This is safe because we clamped it above.
llvm::Value* n_i32 =
b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width));
// Create 2^n as an fp32. This works because -126 <= n <= 127 means that n is
// within the bounds for an fp32 exponent.
auto splat_i32 = [&](int32 v) {
return b->CreateVectorSplat(vector_width, b->getInt32(v));
};
const int32 kF32SignificandBits = 23;
llvm::Value* exp_bias = splat_i32(0x7f);
llvm::Value* pow2 =
b->CreateBitCast(b->CreateShl(b->CreateAdd(n_i32, exp_bias),
splat_i32(kF32SignificandBits)),
vsl.vector_type());
// Return z * 2^n.
return vsl.Mul(z, pow2);
}
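
A scalar C++ sketch of the scheme described in the comments above may help make the steps concrete: pick n ≈ input / log(2), evaluate a polynomial for e^a, and build 2^n directly from the fp32 exponent field. This is a hedged illustration, not the XLA emitter; ExpSketch is a hypothetical name, and the p0..p3 and C2 constants are assumed to be the standard Cephes expf values (only C1, p4 and p5 are visible in this hunk).

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Illustrative scalar version of the vectorized scheme above; constants for
// p0..p3 and C2 are assumed (standard Cephes expf values).
float ExpSketch(float x) {
  // Clamp as in the vector code; exp over/underflows outside this range anyway.
  x = std::min(88.8f, std::max(-104.0f, x));
  // n = floor(x * log2(e) + 0.5), then clamp so 2^n stays representable.
  float n = std::floor(x * 1.44269504088896341f + 0.5f);
  n = std::min(127.0f, std::max(-126.0f, n));
  // a = x - n * log(2), with log(2) split into C1 + C2 for extra precision.
  const float kC1 = 0.693359375f;
  const float kC2 = -2.12194440e-4f;  // assumed; not shown in the diff
  float a = x - n * kC1 - n * kC2;
  // Polynomial for e^a, accurate for a roughly in (-0.5, 0.5).
  float z = 1.9875691500e-4f;    // p0 (assumed)
  z = z * a + 1.3981999507e-3f;  // p1 (assumed)
  z = z * a + 8.3334519073e-3f;  // p2 (assumed)
  z = z * a + 4.1665795894e-2f;  // p3 (assumed)
  z = z * a + 1.6666665459e-1f;  // p4
  z = z * a + 5.0000001201e-1f;  // p5
  z = z * (a * a) + a + 1.0f;
  // 2^n via the fp32 exponent field: (n + 127) << 23, reinterpreted as float.
  int32_t bits = (static_cast<int32_t>(n) + 127) << 23;
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));
  return z * pow2n;
}
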
llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include <stdint.h>
#include <algorithm>
#include <list>
#include <utility>
@ -28,7 +29,6 @@ limitations under the License.
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
@ -146,16 +147,18 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
// registered name may not.
std::string stripped_name(name.begin() + 1, name.end());
func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name);
func_addr =
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
} else {
func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
}
if (func_addr == nullptr) {
LOG(ERROR)
<< "Unable to resolve runtime symbol: `" << name
<< "'. Hint: if the symbol a custom call target, make sure you've "
"registered it with the JIT using REGISTER_CUSTOM_CALL_TARGET.";
"registered it with the JIT using "
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
return nullptr;
}
llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
@ -209,14 +212,15 @@ llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
namespace {
// Register some known symbols with the CustomCallTargetRegistry.
bool RegisterKnownJITSymbols() {
CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
xla::CustomCallTargetRegistry* registry =
xla::CustomCallTargetRegistry::Global();
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
do { \
auto* function_address = \
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
function_address); \
function_address, "Host"); \
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
"__xla_cpu_runtime_" #base_name); \
} while (false)
@ -247,8 +251,10 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
"Host");
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
"Host");
#undef REGISTER_CPU_RUNTIME_SYMBOL
@ -256,11 +262,12 @@ bool RegisterKnownJITSymbols() {
// Unfortunately the double versions are overloaded on some systems, e.g.
// Mac, so we need an explicit cast. This requires passing the function
// signature for that case.
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
registry->Register( \
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \
registry->Register(#name, \
reinterpret_cast<void*>(static_cast<double_sig>(name)), \
"Host"); \
} while (false)
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
@ -318,8 +325,9 @@ bool RegisterKnownJITSymbols() {
#ifdef __APPLE__
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
registry->Register("__sincosf_stret",
reinterpret_cast<void*>(__sincosf_stret));
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
reinterpret_cast<void*>(__sincosf_stret), "Host");
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
"Host");
#else
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
@ -332,19 +340,19 @@ bool RegisterKnownJITSymbols() {
#undef REGISTER_LIBM_SYMBOL
registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
registry->Register("memmove", reinterpret_cast<void*>(memmove));
registry->Register("memset", reinterpret_cast<void*>(memset));
registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
#ifdef __APPLE__
registry->Register("__bzero", reinterpret_cast<void*>(bzero));
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
registry->Register("memset_pattern16",
reinterpret_cast<void*>(memset_pattern16));
reinterpret_cast<void*>(memset_pattern16), "Host");
#endif
#ifdef MEMORY_SANITIZER
registry->Register("__msan_unpoison",
reinterpret_cast<void*>(__msan_unpoison));
reinterpret_cast<void*>(__msan_unpoison), "Host");
#endif
return true;

View File

@ -107,13 +107,19 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
const llvm::APFloat& low,
const llvm::APFloat& high) {
CHECK(!low.isNaN());
CHECK(!high.isNaN());
CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
AssertCorrectTypes({a});
llvm::Type* type = a->getType();
CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
CHECK(scalar_type_->isFloatingPointTy());
return llvm_ir::EmitFloatMin(
llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
GetConstantFloat(type, high), b_);
llvm::Value* low_value = GetConstantFloat(type, low);
llvm::Value* high_value = GetConstantFloat(type, high);
a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
return a;
}
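
One subtlety of the select-based clamp above: FCmpUGE and FCmpULE are unordered comparisons, so they return true when `a` is NaN, and a NaN input is therefore passed through unchanged rather than being forced to `low`. A minimal scalar analogue (illustrative only; ClampSketch is a hypothetical name):

// Mimics the unordered FCmp semantics: !(a < low) and !(a > high) are both
// true for NaN, so a NaN input comes back unchanged instead of becoming low.
float ClampSketch(float a, float low, float high) {
  a = !(a < low) ? a : low;
  a = !(a > high) ? a : high;
  return a;
}
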
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,

View File

@ -100,8 +100,10 @@ class VectorSupportLibrary {
llvm::Value* Floor(llvm::Value* a);
// Precondition: Neither `low` nor `high` is nan.
llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
const llvm::APFloat& high);
llvm::Value* SplatFloat(const llvm::APFloat& d) {
return GetConstantFloat(vector_type(), d);
}

View File

@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
namespace xla {
namespace cpu {
CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
static auto* registry = new CustomCallTargetRegistry;
@ -24,16 +23,17 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
}
void CustomCallTargetRegistry::Register(const std::string& symbol,
void* address) {
void* address,
const std::string& platform) {
std::lock_guard<std::mutex> lock(mu_);
registered_symbols_[symbol] = address;
registered_symbols_[std::make_pair(symbol, platform)] = address;
}
void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
const std::string& platform) const {
std::lock_guard<std::mutex> lock(mu_);
auto it = registered_symbols_.find(symbol);
auto it = registered_symbols_.find(std::make_pair(symbol, platform));
return it == registered_symbols_.end() ? nullptr : it->second;
}
} // namespace cpu
} // namespace xla
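
To make the platform-qualified registry API concrete, here is a hedged usage sketch based only on the Register/Lookup signatures visible in this diff. The target name and function are hypothetical, and the calling convention XLA expects of a real custom-call target is out of scope here.

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

// Hypothetical target; a real CPU custom call must follow XLA's custom-call
// calling convention, which this sketch does not model.
static void MyNoopTarget() {}

// Registration is keyed by (symbol, platform); "Host" is the platform string
// the CPU JIT uses in the simple_orc_jit.cc changes above.
static bool registered = [] {
  xla::CustomCallTargetRegistry::Global()->Register(
      "MyNoopTarget", reinterpret_cast<void*>(&MyNoopTarget), "Host");
  return true;
}();

// Later, e.g. from a JIT symbol resolver, the same (symbol, platform) pair
// resolves the address; a different platform string yields nullptr:
//   void* addr = xla::CustomCallTargetRegistry::Global()->Lookup(
//       "MyNoopTarget", "Host");
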

Some files were not shown because too many files have changed in this diff.