Merge branch 'master' into java-eager-tensor
commit 4384648a78
@@ -293,7 +293,6 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
tags = [
"no_oss", # http://b/119522529
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since

@@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/kernels/logging_ops.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"

@@ -36,6 +36,7 @@ py_binary(
name = "make_test_graphs",
testonly = 1,
srcs = ["make_test_graphs.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",

@@ -220,6 +220,7 @@ cc_library(
name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:graph"],
)

@@ -262,7 +263,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",

@@ -270,6 +270,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",

@@ -466,6 +467,9 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],

@@ -165,6 +165,18 @@ bool LogNotCompilableAndReturn(const Node& node,
return false;
}

bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd";
}

bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd" || node.type_string() == "Qr";
}

bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) {
// _Arg nodes in a top-level function represent feeds and _Retval nodes in a

@@ -228,8 +240,12 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
"resource variable op in called function");
}

if (!op_filter_.allow_svd_op && node.type_string() == "Svd") {
return LogNotCompilableAndReturn(node, "Svd ops disabled");
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues");
}

if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation");
}

return true;

@@ -248,7 +264,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops;
op_filter.allow_svd_op = registration.cluster_svd_op;
op_filter.allow_slow_and_inaccurate_ops =
registration.cluster_slow_and_inaccurate_ops;
return op_filter;
}
@@ -97,10 +97,9 @@ class RecursiveCompilabilityChecker {
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;

// Whether the "Svd" op should be auto-clustered. The XLA implementation of
// this op has some performance (b/128001705) and possibly correctness
// (b/127344411) issues so we avoid auto-clustering it.
bool allow_svd_op;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool allow_slow_and_inaccurate_ops;
};

RecursiveCompilabilityChecker(const OperationFilter* op_filter,

@@ -119,6 +118,11 @@ class RecursiveCompilabilityChecker {
return IsCompilableCall(call_def, /*depth=*/0, lib_runtime);
}

// Returns true if XLA supports this Op, but we don't want to cluster it (i.e.
// due to performance or correctness concerns).
bool OpIsInaccurate(const Node& node);
bool OpIsSlow(const Node& node);

private:
bool IsCompilableNode(const Node& node, int depth,
FunctionLibraryRuntime* lib_runtime);

@@ -371,7 +371,8 @@ class PredicateFactory {
Predicate** predicate) {
TensorId tensor_id(node->name(), output_idx);

bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
bool is_boolean_tensor =
BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
TF_RET_CHECK(!must_be_true || is_boolean_tensor);

if (node->type_string() == "Const" && must_be_true) {

@@ -1067,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
}

TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
Scope root = Scope::NewRootScope().ExitOnError();

Output condition_ref_var =
ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var);

Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);

FixupSourceAndSinkEdges(root.graph());

PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0");
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0");
}

} // namespace
} // namespace tensorflow
@@ -97,28 +97,19 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
return Status::OK();
}

Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
absl::optional<jit::DeviceId>* out_device_picked) {
if (out_can_pick_device) {
*out_can_pick_device = true;
}

xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
bool failure_to_pick_is_error) {
#define FAILED_TO_PICK_DEVICE(failing_status) \
do { \
if (out_can_pick_device) { \
*out_can_pick_device = false; \
return Status::OK(); \
} else { \
if (failure_to_pick_is_error) { \
return failing_status; \
} else { \
return {absl::nullopt}; \
} \
} while (false)

TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);

absl::optional<jit::DeviceId> maybe_gpu_device;
absl::optional<jit::DeviceId> maybe_cpu_device;
absl::optional<jit::DeviceId> maybe_unknown_device;

@@ -182,17 +173,15 @@ Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
}
}

if (out_device_picked) {
if (maybe_gpu_device) {
*out_device_picked = *maybe_gpu_device;
} else if (maybe_unknown_device) {
*out_device_picked = *maybe_unknown_device;
} else {
*out_device_picked = *maybe_cpu_device;
}
if (maybe_gpu_device) {
return {*maybe_gpu_device};
} else if (maybe_unknown_device) {
return {*maybe_unknown_device};
} else if (maybe_cpu_device) {
return {*maybe_cpu_device};
}

return Status::OK();
FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));

#undef FAILED_TO_PICK_DEVICE
}

@@ -200,21 +189,18 @@ Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
absl::optional<jit::DeviceId> device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
device_info_cache, devices, allow_mixing_unknown_and_cpu,
/*out_can_pick_device=*/nullptr, &device));
return *device;
TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/true));
return *device_id;
}

xla::StatusOr<bool> CanPickDeviceForXla(
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
bool can_pick_device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
&can_pick_device,
/*out_device_picked=*/nullptr));
return can_pick_device;
return PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/false);
}
} // namespace tensorflow

@@ -71,17 +71,34 @@ class DeviceSet {
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
if (word & (1ull << bit_index)) {
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
while (word != 0) {
uint64 only_lowest_bit_set = word & -word;
// The number of trailing zeros in a non-zero word is the index of the
// least significant 1.
int bit_index = ctz_uint64(word);
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
word ^= only_lowest_bit_set;
}
}
}

private:
static int ctz_uint64(uint64 x) {
DCHECK_NE(x, 0);
#ifdef __GNUC__
return __builtin_ctzl(x);
#else
int result = 0u;
while ((x & 1u) == 0u) {
x >>= 1;
++result;
}
return result;
#endif
}

absl::InlinedVector<uint64, 1> storage_;

const int kWordSize = 64;
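The loop above replaces a fixed 64-iteration scan with a walk over the set bits only. A minimal standalone sketch of the same idiom (not part of the diff; plain C++ using the same GCC/Clang builtin the header falls back to):

#include <cstdint>
#include <iostream>

int main() {
  uint64_t word = 0b101001;  // bits 0, 3 and 5 are set
  while (word != 0) {
    uint64_t only_lowest_bit_set = word & -word;  // isolates the least significant 1
    int bit_index = __builtin_ctzll(word);        // its index = number of trailing zeros
    std::cout << "bit " << bit_index << " is set\n";
    word ^= only_lowest_bit_set;                  // clear that bit and continue
  }
  return 0;
}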
@@ -181,9 +198,12 @@ xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);

// This is like `PickDeviceForXla` except that it returns false (instead of a
// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
// non-OK Status) if no unambiguous choice of device exists.
xla::StatusOr<bool> CanPickDeviceForXla(
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow
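A hedged usage sketch (not part of the diff) of the new return type; the helper name `ClustersShareADevice` is an assumption, but the call pattern mirrors how `AreDevicesCompatible` consumes it later in this change:

// Returns true if an unambiguous XLA device can be picked for `devices`.
xla::StatusOr<bool> ClustersShareADevice(
    const jit::DeviceInfoCache& device_info_cache, const jit::DeviceSet& devices) {
  TF_ASSIGN_OR_RETURN(
      absl::optional<jit::DeviceId> maybe_device,
      MaybePickDeviceForXla(device_info_cache, devices,
                            /*allow_mixing_unknown_and_cpu=*/false));
  // nullopt means "no unambiguous choice", a normal outcome here, while a
  // non-OK status still propagates as a real error.
  return maybe_device.has_value();
}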
@@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
std::map<string, int>{}});
}
bool modified;
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
graph_out.get(), flr, lib_def.get());
graph_out.get(), flr, lib_def.get(), &modified);
if (!s.ok()) return s;

GraphDef graphdef_out;

@@ -1105,7 +1106,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",

@@ -1985,7 +1988,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});

@@ -2110,7 +2115,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},

@@ -2258,7 +2265,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",

@@ -2271,7 +2279,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});

@@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/encapsulate_util.h"

#include <algorithm>
#include <iterator>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/shape_inference.h"

@@ -24,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {

@@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
return Status::OK();
}

StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name) {
auto cluster_deps = absl::make_unique<
absl::flat_hash_map<string, absl::flat_hash_set<string>>>();

for (const Edge* e : g->edges()) {
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);

if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
if (dst_deps_it == cluster_deps->end()) {
cluster_deps->insert(std::make_pair(
*dst_outside_compilation,
absl::flat_hash_set<string>({*src_outside_compilation})));
} else {
dst_deps_it->second.insert(*src_outside_compilation);
}
}
}

auto cluster_deps_ordered =
absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();

for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
std::vector<string> ordered_deps(it->second.begin(), it->second.end());
std::sort(ordered_deps.begin(), ordered_deps.end());
cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
}

return std::move(cluster_deps_ordered);
}

Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges
@@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

@@ -89,6 +91,15 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core;
};

// Finds dependencies between outside compilation clusters, including both data
// dependencies and control dependencies. cluster_deps maps the name of an
// outside compilation cluster to a set of names of outside compilation clusters
// that it depends on.
stream_executor::port::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);

// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//

@@ -15,12 +15,14 @@ limitations under the License.

#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"

@@ -287,15 +289,20 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
return results;
}

string host_compute_node_name(const string& original_oc_name) {
return absl::StrCat("outside_compilation_", original_oc_name,
"_host_compute");
}

// Builds XlaHostCompute NodeDef from the outside compilation call node.
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
const Node* call_node, const std::map<string, int>& host_compute_core) {
const Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
string original_oc_name;
TF_RETURN_IF_ERROR(GetNodeAttr(
call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
NodeDefBuilder host_compute_builder(
absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
"XlaHostCompute");
NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
"XlaHostCompute");

// Copy all attributes.
for (auto attr : call_node->attrs()) {

@@ -309,9 +316,25 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
host_compute_builder.Attr("tpu_core", core);
}

// Set input tokens.
host_compute_builder.Attr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName});
// Set input tokens and the other outside compilation clusters that the
// current cluster depends on in `kXlaTokenArgNodeName`. This is needed
// because when outside compilation subgraphs are encapsulated and moved to
// the host graph, control/data edges between them will only be reflected in
// the host graph. From XLA's perspective, two originally dependent clusters
// are no longer connected, which makes them look like they can be scheduled
// for execution in arbitrary order even though in fact they must be executed
// in order according to their host-side graph dependency. This can cause
// deadlock. Therefore, we hint XLA what the correct ordering of these
// clusters should be to avoid deadlocks.
std::vector<string> xla_token_input_nodes;
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
auto cluster_deps_it = cluster_deps.find(original_oc_name);
if (cluster_deps_it != cluster_deps.end()) {
for (auto dep : cluster_deps_it->second) {
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
}
}
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
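For example (an illustrative sketch, not part of the diff): for an outside compilation cluster "1" that depends on cluster "0", as in the tests added later in this commit, the loop above produces roughly

std::vector<string> xla_token_input_nodes = {
    "_xla_token_arg_node",                  // kXlaTokenArgNodeName
    "outside_compilation_0_host_compute"};  // host_compute_node_name("0")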
// Populate inputs.
std::vector<DataType> input_dtypes;

@@ -371,7 +394,8 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it.
bool has_edge = false;
for (auto e : call_node->in_edges()) {

@@ -393,8 +417,9 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
}

// Build XlaHostCompute NodeDef.
TF_ASSIGN_OR_RETURN(NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core));
TF_ASSIGN_OR_RETURN(
NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
TF_ASSIGN_OR_RETURN(Node * host_compute_node,
ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();

@@ -1589,6 +1614,11 @@ Status ExtractOutsideCompilationForFunction(
// We cannot early return here, because we might have outside compilation in
// If/While function body.

// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(
fbody->graph, outside_compilation_attr_name));

// Preprocess edges between different outside compilations. They will be
// restored in `ConstructHostGraph()`.
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(

@@ -1643,7 +1673,7 @@ Status ExtractOutsideCompilationForFunction(
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core));
graph_out.get(), n, host_compute_core, *cluster_deps));
}

// Handle nodes with associated functions.

@@ -1691,11 +1721,13 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified) {
if (VLOG_IS_ON(4)) {
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}

*modified = false;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;

@@ -1711,6 +1743,7 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
*modified |= has_outside_compilation;

string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];

@@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified);

} // namespace tensorflow
@@ -922,4 +922,145 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
}
}

TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterDataDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "identity0" (outside compilation cluster "1")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");

PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));

// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();

// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);

// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));

std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}

TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterControlDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
// control dependent on cluster "0")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(
s.WithOpName("identity1").WithControlDependencies(identity0), const0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");

PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));

// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();

// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);

// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));

std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}
} // namespace tensorflow
@@ -13,11 +13,23 @@ cc_library(
srcs = ["graphcycles.cc"],
hdrs = ["graphcycles.h"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "ordered_set",
hdrs = ["ordered_set.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)

@@ -31,3 +43,14 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)

tf_cc_test(
name = "ordered_set_test",
srcs = ["ordered_set_test.cc"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
@@ -38,13 +38,16 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {

typedef std::unordered_set<int32> NodeSet;
using NodeSet = absl::flat_hash_set<int32>;
using OrderedNodeSet = OrderedSet<int32>;

template <typename T>
struct VecStruct {
typedef absl::InlinedVector<T, 4> type;

@@ -53,13 +56,11 @@ template <typename T>
using Vec = typename VecStruct<T>::type;

struct Node {
Node() : in(4), out(4) {} // Small hashtables for in/out edges

int32 rank; // rank number assigned by Pearce-Kelly algorithm
bool visited; // Temporary marker used by depth-first-search
void* data; // User-supplied data
NodeSet in; // List of immediate predecessor nodes in graph
NodeSet out; // List of immediate successor nodes in graph
OrderedNodeSet in; // List of immediate predecessor nodes in graph
OrderedNodeSet out; // List of immediate successor nodes in graph
};

} // namespace

@@ -96,7 +97,7 @@ bool GraphCycles::CheckInvariants() const {
if (!ranks.insert(nx->rank).second) {
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
}
for (auto y : nx->out) {
for (int32 y : nx->out.GetSequence()) {
Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) {
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "

@@ -127,14 +128,14 @@ int32 GraphCycles::NewNode() {

void GraphCycles::RemoveNode(int32 node) {
Node* x = rep_->nodes_[node];
for (auto y : x->out) {
rep_->nodes_[y]->in.erase(node);
for (int32 y : x->out.GetSequence()) {
rep_->nodes_[y]->in.Erase(node);
}
for (auto y : x->in) {
rep_->nodes_[y]->out.erase(node);
for (int32 y : x->in.GetSequence()) {
rep_->nodes_[y]->out.Erase(node);
}
x->in.clear();
x->out.clear();
x->in.Clear();
x->out.Clear();
rep_->free_nodes_.push_back(node);
}

@@ -147,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) {
}

bool GraphCycles::HasEdge(int32 x, int32 y) const {
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
return rep_->nodes_[x]->out.Contains(y);
}

void GraphCycles::RemoveEdge(int32 x, int32 y) {
rep_->nodes_[x]->out.erase(y);
rep_->nodes_[y]->in.erase(x);
rep_->nodes_[x]->out.Erase(y);
rep_->nodes_[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}

@@ -168,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes_[x];
if (!nx->out.insert(y).second) {
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}

Node* ny = r->nodes_[y];
ny->in.insert(x);
ny->in.Insert(x);

if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.

@@ -185,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (!ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.erase(y);
ny->in.erase(x);
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_);

@@ -212,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
nn->visited = true;
r->deltaf_.push_back(n);

for (auto w : nn->out) {
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes_[w];
if (nw->rank == upper_bound) {
return false; // Cycle

@@ -238,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
nn->visited = true;
r->deltab_.push_back(n);

for (auto w : nn->in) {
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes_[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack_.push_back(w);

@@ -324,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
return path_len;
}

for (auto w : r->nodes_[n]->out) {
for (auto w : r->nodes_[n]->out.GetSequence()) {
if (seen.insert(w).second) {
r->stack_.push_back(w);
}

@@ -378,31 +379,35 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
}

Node* nb = rep_->nodes_[b];
std::unordered_set<int32> out = std::move(nb->out);
std::unordered_set<int32> in = std::move(nb->in);
for (auto y : out) {
rep_->nodes_[y]->in.erase(b);
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32 y : out.GetSequence()) {
rep_->nodes_[y]->in.Erase(b);
}
for (auto y : in) {
rep_->nodes_[y]->out.erase(b);
for (int32 y : in.GetSequence()) {
rep_->nodes_[y]->out.Erase(b);
}
rep_->free_nodes_.push_back(b);

for (auto y : out) {
rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
for (int32 y : out.GetSequence()) {
InsertEdge(a, y);
}
for (auto y : in) {

rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
for (int32 y : in.GetSequence()) {
InsertEdge(y, a);
}

return true;
}

std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out;
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out.GetSequence();
}

std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in;
absl::Span<const int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in.GetSequence();
}

namespace {

@@ -444,7 +449,7 @@ string GraphCycles::DebugString() const {
continue;
}

for (int32 succ : rep_->nodes_[i]->out) {
for (int32 succ : rep_->nodes_[i]->out.GetSequence()) {
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
}
}
@@ -40,8 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.

#include <unordered_set>

#include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

@@ -119,8 +118,8 @@ class GraphCycles {
// Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const;

std::unordered_set<int32> Successors(int32 node) const;
std::unordered_set<int32> Predecessors(int32 node) const;
absl::Span<const int32> Successors(int32 node) const;
absl::Span<const int32> Predecessors(int32 node) const;

// Returns all nodes in post order.
//
tensorflow/compiler/jit/graphcycles/ordered_set.h (new file, 85 lines)

@@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_

#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}

// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
DCHECK(it != value_to_index_.end());

// Since we don't want to move values around in `value_sequence_` we swap
// the value in the last position with the value to be deleted and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}

void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}

void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}

bool Contains(T value) const { return value_to_index_.contains(value); }
size_t Size() const { return value_sequence_.size(); }

absl::Span<T const> GetSequence() const { return value_sequence_; }

private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;

// Maps values to their indices in `value_sequence_`.
absl::flat_hash_map<T, int> value_to_index_;
};
} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
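A minimal usage sketch of the class above (not part of the diff); it only restates the semantics that the tests below exercise:

#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"

void OrderedSetDemo() {
  tensorflow::OrderedSet<int> s;
  s.Insert(90);   // returns true: new element
  s.Insert(100);  // returns true
  s.Insert(90);   // returns false: duplicate, order unchanged
  // GetSequence() is {90, 100}: iteration order follows insertion order.
  s.Erase(90);    // swaps 90 with the last element and pops it
  // GetSequence() is now {100}; Insert/Erase/Contains are constant time on average.
}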
tensorflow/compiler/jit/graphcycles/ordered_set_test.cc (new file, 117 lines)

@@ -0,0 +1,117 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {
TEST(OrderedSetTest, Insert) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));

EXPECT_FALSE(ordered_set.Insert(100));

EXPECT_EQ(ordered_set.Size(), 3);

EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_TRUE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));

EXPECT_FALSE(ordered_set.Contains(40));

std::array<int, 3> expected_sequence = {90, 100, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}

TEST(OrderedSetTest, Erase) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));

ordered_set.Erase(100);

EXPECT_EQ(ordered_set.Size(), 2);

EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));

std::array<int, 2> expected_sequence_0 = {90, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0);

ordered_set.Erase(80);

EXPECT_EQ(ordered_set.Size(), 1);

EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));

std::array<int, 1> expected_sequence_1 = {90};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1);

ordered_set.Erase(90);

EXPECT_EQ(ordered_set.Size(), 0);

EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));

std::array<int, 0> expected_sequence_2 = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2);
}

TEST(OrderedSetTest, Clear) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));

ordered_set.Clear();

EXPECT_EQ(ordered_set.Size(), 0);

EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));

std::array<int, 0> expected_sequence = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}

TEST(OrderedSetTest, LargeInsertions) {
const int kSize = 50 * 9000;

OrderedSet<int> ordered_set;

for (int i = 0; i < kSize; i++) {
EXPECT_TRUE(ordered_set.Insert(i + 500));
}

for (int i = 0; i < kSize; i++) {
EXPECT_EQ(ordered_set.GetSequence()[i], i + 500);
}
}
} // namespace
} // namespace tensorflow
@@ -62,7 +62,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<XlaAllocator> xla_allocator;
xla::DeviceMemoryAllocator* device_allocator = nullptr;
se::DeviceMemoryAllocator* device_allocator = nullptr;

if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;

@@ -40,7 +40,7 @@ class XlaPlatformInfo {
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<XlaAllocator> xla_allocator,
xla::DeviceMemoryAllocator* device_allocator)
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),

@@ -55,7 +55,7 @@ class XlaPlatformInfo {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}

xla::DeviceMemoryAllocator* allocator() const {
se::DeviceMemoryAllocator* allocator() const {
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
}
DeviceType device_type() const { return device_type_; }

@@ -86,7 +86,7 @@ class XlaPlatformInfo {
// then device_allocator_ is null and xla_allocator_ points to an appropriate
// XlaAllocator instance.
std::unique_ptr<XlaAllocator> xla_allocator_;
xla::DeviceMemoryAllocator* device_allocator_;
se::DeviceMemoryAllocator* device_allocator_;

TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};

@@ -270,7 +270,7 @@ class MarkForCompilationPassImpl {
StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);

StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
const Cluster& to);
const Cluster& from, const Cluster& to);

// Returns true if the devices in `cluster_a` and `cluster_b` are compatible
// and therefore not a hindrance for combining the two clusters into a larger

@@ -698,7 +698,7 @@ Status MarkForCompilationPassImpl::DumpDebugInfo() {

StatusOr<bool>
MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
const Cluster& cluster_to) {
const Cluster& cluster_from, const Cluster& cluster_to) {
// If any of the consumer's producers are on a different device, do not
// cluster these nodes. This prevents other work on this device from being
// delayed by work on other devices. We consider predecessors of the entire

@@ -722,6 +722,11 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
if (!devices_compatible) {
return true;
}
TF_ASSIGN_OR_RETURN(devices_compatible,
AreDevicesCompatible(cluster_from, *cluster_in));
if (!devices_compatible) {
return true;
}
}
}

@@ -1026,7 +1031,7 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
}

TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
ClusteringWillIntroduceInterDeviceDependency(*to));
ClusteringWillIntroduceInterDeviceDependency(*from, *to));

if (will_introduce_cross_device_dependency) {
return LogNotContractableAndReturnFalse(

@@ -1062,8 +1067,16 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
Cluster* cluster_from) {
bool changed = false;
for (int to :
cycles_graph_.Successors(cluster_from->cycles_graph_node_id())) {

// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32> successors_copy = [&] {
absl::Span<const int32> successors =
cycles_graph_.Successors(cluster_from->cycles_graph_node_id());
return std::vector<int32>(successors.begin(), successors.end());
}();

for (int to : successors_copy) {
iteration_count_++;
if (to >= graph_->num_node_ids()) {
// Node is a fictitious node that is present only in the cycle detection

@@ -1265,19 +1278,15 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
DeviceSet devices = cluster_a.devices();
devices.UnionWith(cluster_b.devices());

// First check if we will even be able to pick a device for the larger
// combined cluster.
TF_ASSIGN_OR_RETURN(
bool can_pick_device,
CanPickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
if (!can_pick_device) {
absl::optional<jit::DeviceId> maybe_chosen_device,
MaybePickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
if (!maybe_chosen_device.has_value()) {
return false;
}

TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
PickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
jit::DeviceId chosen_device = *maybe_chosen_device;

// If we are able to pick a device `chosen_device` for the larger cluster, the
// resource operations in `cluster_a` and `cluster_b` must be placed on the

@@ -1415,7 +1424,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_svd_op = true;
op_filter.allow_slow_and_inaccurate_ops = true;

return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
.IsCompilableCall(ndef, flr);
@ -1112,6 +1112,45 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
|
||||
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
|
||||
// This is similar to the 'DontClusterMergingNodes' above, except
|
||||
// MatMulCombined is placed on the CPU.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
|
||||
absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
|
||||
absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
|
||||
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
|
||||
Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
|
||||
ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
|
||||
Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
|
||||
Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
|
||||
|
||||
Output combined =
|
||||
ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
|
||||
n->set_assigned_device_name(string(xla_cpu_dev0));
|
||||
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
|
||||
n->set_assigned_device_name(string(xla_gpu_dev0));
|
||||
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
|
||||
n->set_assigned_device_name(string(xla_gpu_dev1));
|
||||
}
|
||||
}
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
// Each of the MatMuls should be in a separate cluster.
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
|
||||
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
|
||||
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
|
||||
EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
|
||||
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
|
||||
}
|
||||
|
||||
// TODO(b/117085735): This form of clustering should be prevented.
|
||||
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
|
||||
// MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
|
||||
|
@ -60,7 +60,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_svd_op = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -95,7 +95,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_svd_op = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -63,7 +63,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_svd_op = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
||||
|
@ -168,11 +168,11 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
|
||||
: xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
|
||||
XlaAllocator::~XlaAllocator() {}
|
||||
|
||||
xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) {
|
||||
AllocationAttributes attrs;
|
||||
attrs.no_retry_on_failure = !retry_on_failure;
|
||||
@ -184,8 +184,8 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
"Out of memory while trying to allocate ", size, " bytes.");
|
||||
}
|
||||
}
|
||||
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
|
||||
device_ordinal, this);
|
||||
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
|
||||
device_ordinal, this);
|
||||
}
|
||||
|
||||
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
@ -194,7 +194,7 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
}
|
||||
|
||||
XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
|
||||
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors, bool use_multiple_streams)
|
||||
: client_(client),
|
||||
xla_allocator_(xla_allocator),
|
||||
@ -374,7 +374,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
} else {
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
@ -435,7 +435,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
write.type, write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
|
@ -23,14 +23,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
#include "tensorflow/stream_executor/owning_device_memory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class XlaAllocator;
|
||||
@ -108,11 +108,11 @@ Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
|
||||
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
class XlaAllocator : public se::DeviceMemoryAllocator {
|
||||
public:
|
||||
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
|
||||
~XlaAllocator() override;
|
||||
xla::StatusOr<xla::OwningDeviceMemory> Allocate(
|
||||
xla::StatusOr<se::OwningDeviceMemory> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) override;
|
||||
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
|
||||
|
||||
@ -142,7 +142,7 @@ class XlaComputationLaunchContext {
|
||||
// because we track inter-stream dependencies through events inside XlaTensor
|
||||
// objects.
|
||||
XlaComputationLaunchContext(xla::LocalClient* client,
|
||||
xla::DeviceMemoryAllocator* xla_allocator,
|
||||
se::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors,
|
||||
bool use_multiple_streams);
|
||||
|
||||
@ -186,7 +186,7 @@ class XlaComputationLaunchContext {
|
||||
|
||||
private:
|
||||
xla::LocalClient* client_;
|
||||
xla::DeviceMemoryAllocator* xla_allocator_;
|
||||
se::DeviceMemoryAllocator* xla_allocator_;
|
||||
bool allocate_xla_tensors_;
|
||||
bool use_multiple_streams_;
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
|
||||
|
@ -59,7 +59,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype,
|
||||
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
|
||||
uint64 size =
|
||||
client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
|
||||
TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
|
||||
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
|
||||
client->backend().memory_allocator()->Allocate(
|
||||
device_ordinal, size, /*retry_on_failure=*/false));
|
||||
// Move our buffer into shaped_buffer, which takes ownership of it.
|
||||
|
@ -458,10 +458,6 @@ tf_xla_py_test(
|
||||
name = "extract_image_patches_op_test",
|
||||
size = "small",
|
||||
srcs = ["extract_image_patches_op_test.py"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -57,7 +57,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
|
||||
float getDynamicRangeMax() const override { return 0.f; }
|
||||
#endif
|
||||
|
||||
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||
|
||||
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||
|
||||
bool isShape() const override { return false; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
nvinfer1::DataType trt_dtype_;
|
||||
nvinfer1::Dims trt_dims_;
|
||||
@ -4455,6 +4463,40 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertSquaredDifference(OpConverterParams* params) {
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
// Broadcast inputs.
|
||||
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
|
||||
TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape(
|
||||
inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
|
||||
nvinfer1::ITensor* tensor_l = nullptr;
|
||||
nvinfer1::ITensor* tensor_r = nullptr;
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l));
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Subtract x - y.
|
||||
nvinfer1::IElementWiseLayer* sub =
|
||||
params->converter->network()->addElementWise(
|
||||
*tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name());
|
||||
// Multiply (x - y) * (x - y).
|
||||
nvinfer1::IElementWiseLayer* mul =
|
||||
params->converter->network()->addElementWise(
|
||||
*sub->getOutput(0), *sub->getOutput(0),
|
||||
nvinfer1::ElementWiseOperation::kPROD);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name());
|
||||
|
||||
params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
Status ConvertCombinedNMS(OpConverterParams* params) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -4641,7 +4683,6 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
|
||||
(*registration)["ExpandDims"] = ConvertExpandDims;
|
||||
(*registration)["GatherV2"] = ConvertGather;
|
||||
(*registration)["Identity"] = ConvertIdentity; // Identity should be removed
|
||||
(*registration)["LeakyRelu"] = ConvertLeakyRelu;
|
||||
(*registration)["MatMul"] = ConvertMatMul;
|
||||
(*registration)["Pack"] = ConvertPack;
|
||||
@ -4650,11 +4691,11 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)["Reshape"] = ConvertReshape;
|
||||
(*registration)["Rsqrt"] = ConvertRsqrt;
|
||||
(*registration)["Slice"] = ConvertSlice;
|
||||
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
|
||||
(*registration)["Softmax"] = ConvertSoftmax;
|
||||
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
|
||||
(*registration)["Split"] = ConvertSplit;
|
||||
(*registration)["Square"] = ConvertSquare;
|
||||
(*registration)["SquaredDifference"] = ConvertSquaredDifference;
|
||||
(*registration)["Squeeze"] = ConvertSqueeze;
|
||||
(*registration)["StridedSlice"] = ConvertStridedSlice;
|
||||
(*registration)["TopKV2"] = ConvertTopK;
|
||||
@ -4688,6 +4729,11 @@ static void RegisterValidatableOpConverters(
|
||||
for (auto arg_minmax_type : {"ArgMin", "ArgMax"}) {
|
||||
(*registration)[arg_minmax_type] = ConvertArgMinMax;
|
||||
}
|
||||
// The following are no-ops during inference and will not be mapped to any TRT
|
||||
// layer.
|
||||
for (auto identity_op_type : {"Identity", "Snapshot", "StopGradient"}) {
|
||||
(*registration)[identity_op_type] = ConvertIdentity;
|
||||
}
|
||||
}
|
||||
|
||||
void TrtNodeValidator::RegisterOpValidators() {
|
||||
|
@ -50,8 +50,8 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda.h"
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -280,6 +280,14 @@ class FakeITensor : public nvinfer1::ITensor {
|
||||
float getDynamicRangeMax() const override { return 0.f; }
|
||||
#endif
|
||||
|
||||
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||
|
||||
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||
|
||||
bool isShape() const override { return false; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
string name_;
|
||||
nvinfer1::Dims dims_;
|
||||
@ -5353,6 +5361,108 @@ TEST_F(OpConverterTest, ConvertClipByValue) {
|
||||
}
|
||||
#endif // IS_TRT_VERSION_GE(5, 1, 2, 0)
|
||||
|
||||
// Get the NodeDef for SquaredDifference.
|
||||
NodeDef GetSquaredDifferenceNodeDef(DataType dtype) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto x = ops::Placeholder(s.WithOpName("x"), dtype);
|
||||
auto y = ops::Placeholder(s.WithOpName("y"), dtype);
|
||||
auto squared_diff =
|
||||
ops::SquaredDifference(s.WithOpName("my_squared_diff"), x, y);
|
||||
return squared_diff.operation.node()->def();
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertSquaredDifference(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
|
||||
struct TestParams {
|
||||
std::vector<int> dims_x;
|
||||
std::vector<int> dims_y;
|
||||
std::vector<CType> value_x;
|
||||
std::vector<CType> value_y;
|
||||
std::vector<int> expected_output_dims;
|
||||
std::vector<CType> expected_output;
|
||||
};
|
||||
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(6);
|
||||
std::vector<TestParams> params = {
|
||||
{
|
||||
/*dims_x=*/{1, 2, 3},
|
||||
/*dims_y=*/{1, 2, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/CastTestVector<int, CType>({0, -1, 3, 0, 10, -7}),
|
||||
/*expected_output_dims=*/{1, 2, 3},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 4, 1, 9, 36, 144}),
|
||||
},
|
||||
{
|
||||
/*dims_x=*/{1, 2, 3},
|
||||
/*dims_y=*/{1, 1, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/CastTestVector<int, CType>({0, 1, 2}),
|
||||
/*expected_output_dims=*/{1, 2, 3},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 0, 0, 9, 9, 9}),
|
||||
},
|
||||
};
|
||||
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
test->Reset();
|
||||
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(dtype);
|
||||
test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype));
|
||||
test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype));
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output));
|
||||
EXPECT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
DataVec input_data{{"x", test::AsTensor<CType>(params[i].value_x)},
|
||||
{"y", test::AsTensor<CType>(params[i].value_y)}};
|
||||
DataVec output_data{
|
||||
{"my_squared_diff",
|
||||
ConstructTensor<CType>(params[i].expected_output.size())}};
|
||||
test->BuildAndRun(
|
||||
input_data, &output_data,
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(params[i].expected_output));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertSquaredDifference) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_squared_diff", "SquaredDifference", {});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"SquaredDifference got 0 inputs but expected 2, at my_squared_diff");
|
||||
}
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
||||
AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
|
||||
AddTestTensor("y", {1, 2, 3});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"x\" for SquaredDifference must be "
|
||||
"a tensor, at my_squared_diff");
|
||||
}
|
||||
{
|
||||
// Shapes are not broadcastable, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
||||
AddTestTensor("x", {2, 3});
|
||||
AddTestTensor("y", {7, 5});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Infeasible broadcast scheme");
|
||||
}
|
||||
|
||||
TestConvertSquaredDifference<DT_FLOAT>(this);
|
||||
TestConvertSquaredDifference<DT_HALF>(this);
|
||||
}
|
||||
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -41,7 +41,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -259,7 +259,6 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
opts.runner = ctx->runner();
|
||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda.h"
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -27,7 +27,10 @@ package(
|
||||
default_visibility = [":internal"],
|
||||
)
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
)
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
||||
|
||||
cc_library(
|
||||
|
@ -253,6 +253,7 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
bool fld_modified = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
|
||||
if (it == kNodeTypeToFunctionAttrMapping->end()) {
|
||||
@ -273,9 +274,16 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
n->ClearAttr(func_attr);
|
||||
func.set_name(new_func_name);
|
||||
n->AddAttr(func_attr, func);
|
||||
|
||||
fld_modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (fld_modified) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("functionalize_control_flow_after", *graph,
|
||||
options.flib_def);
|
||||
|
@ -367,7 +367,7 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
@ -380,7 +380,7 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
|
@ -17,7 +17,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -99,23 +101,22 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
||||
// The following code is equivalent to:
|
||||
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
|
||||
int64 kernel_size = 1;
|
||||
std::vector<int64> lhs_shape(num_dims, 1);
|
||||
std::vector<int64> kernel_shape(num_dims, 1);
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
|
||||
lhs_shape[i] = ksizes_[input_dim];
|
||||
kernel_shape[i] = ksizes_[input_dim];
|
||||
kernel_size *= ksizes_[input_dim];
|
||||
}
|
||||
lhs_shape[num_spatial_dims] = depth;
|
||||
lhs_shape[num_spatial_dims + 1] = 1;
|
||||
|
||||
// Builds an identity matrix as a broadcast equality of iotas.
|
||||
// iota = np.arange(np.prod(ksize), depth)
|
||||
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
|
||||
xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth);
|
||||
|
||||
auto lhs = xla::Reshape(iota, lhs_shape);
|
||||
auto filter = xla::ConvertElementType(
|
||||
xla::Eq(lhs, iota, {num_spatial_dims + 1}), type);
|
||||
kernel_shape[num_spatial_dims] = 1;
|
||||
kernel_shape[num_spatial_dims + 1] = kernel_size * depth;
|
||||
xla::Shape iota_kernel_shape =
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {kernel_size, depth, kernel_size});
|
||||
xla::XlaOp filter =
|
||||
xla::Reshape(xla::ConvertElementType(
|
||||
xla::Eq(xla::Iota(builder, iota_kernel_shape, 0),
|
||||
xla::Iota(builder, iota_kernel_shape, 2)),
|
||||
type),
|
||||
kernel_shape);
|
||||
|
||||
xla::ConvolutionDimensionNumbers dims;
|
||||
std::vector<int64> window_strides(num_spatial_dims);
|
||||
@ -148,7 +149,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
||||
|
||||
xla::XlaOp conv =
|
||||
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dims);
|
||||
lhs_dilation, rhs_dilation, dims, depth);
|
||||
ctx->SetOutput(0, conv);
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_1d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_2d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA specific pooling ops.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
@ -327,6 +328,20 @@ class MaxPoolGradOp : public XlaOpKernel {
|
||||
xla::Padding xla_padding =
|
||||
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
|
||||
|
||||
// Create a MaxPool operation to check the expected resulting shape, and
|
||||
// then throw away the operation because we don't actually need it here.
|
||||
TensorShape expected_out_shape;
|
||||
auto pooling =
|
||||
xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
|
||||
XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
|
||||
auto status_or_shape = pooling.builder()->GetShape(pooling);
|
||||
OP_REQUIRES_OK(ctx, status_or_shape.status());
|
||||
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
|
||||
&expected_out_shape));
|
||||
OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
|
||||
errors::Unimplemented("The output dimensions do not match the "
|
||||
"other input values."));
|
||||
|
||||
xla::PrimitiveType element_type;
|
||||
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
|
||||
xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
|
||||
|
@ -733,6 +733,7 @@ Status RearrangeFunctionArgumentPass::Run(
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
bool fld_modified = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
|
||||
if (it == kNodeTypeToFunctionAttrMapping->end()) {
|
||||
@ -753,8 +754,14 @@ Status RearrangeFunctionArgumentPass::Run(
|
||||
n->ClearAttr(func_attr);
|
||||
func.set_name(new_func_name);
|
||||
n->AddAttr(func_attr, func);
|
||||
|
||||
fld_modified = true;
|
||||
}
|
||||
}
|
||||
if (fld_modified) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("rearrange_function_argument_after", *graph,
|
||||
|
@ -773,4 +773,17 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
GraphDef graph_def;
|
||||
g.ToGraphDef(&graph_def);
|
||||
FunctionLibraryDefinition reachable_functions =
|
||||
fld->ReachableDefinitions(graph_def);
|
||||
for (const string& func_name : fld->ListFunctionNames()) {
|
||||
if (!reachable_functions.Find(func_name)) {
|
||||
TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -197,6 +198,10 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
Graph* g, const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
|
||||
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
@ -58,18 +58,13 @@ class XlaCompilationAllocator : public Allocator {
|
||||
|
||||
// Make sure that even tensors with 0 elements have allocated
|
||||
// buffers, so they get ids to track.
|
||||
bool ShouldAllocateEmptyTensors() const override { return true; }
|
||||
|
||||
private:
|
||||
// Don't run any constructors or destructors for complex objects,
|
||||
// since there is no backing store for the tensor to run them
|
||||
// on. strings are the only complex objects currently stored in
|
||||
// Tensors. If others are added, this set of overrides must be
|
||||
// extended to include them.
|
||||
void RunStringCtor(string* p, size_t n) override {}
|
||||
void RunStringDtor(string* p, size_t n) override {}
|
||||
void RunResourceCtor(ResourceHandle* p, size_t n) override {}
|
||||
void RunResourceDtor(ResourceHandle* p, size_t n) override {}
|
||||
//
|
||||
// NOTE: It is the caller's responsibility to track whether an allocated
|
||||
// object is a buffer or an opaque handle. In particular, when this allocator
|
||||
// is used, the caller must not run any constructors or destructors for
|
||||
// complex objects, since there is no backing store for the tensor in which to
|
||||
// place their outputs.
|
||||
bool AllocatesOpaqueHandle() const override { return true; }
|
||||
};
|
||||
|
||||
XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
|
||||
|
@ -339,7 +339,7 @@ class XlaCompiler {
|
||||
// here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
|
||||
// allocate most or all available memory on the device, leaving none for the
|
||||
// compiler to access, unless it can use TensorFlow's allocator.
|
||||
xla::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
};
|
||||
|
||||
explicit XlaCompiler(Options options);
|
||||
|
@ -116,10 +116,9 @@ class XlaOpRegistry {
|
||||
// If we should cluster operations returning DT_VARIANT.
|
||||
bool cluster_variant_ops = false;
|
||||
|
||||
// If we should cluster the "Svd" op. The XLA implemenation of this op has
|
||||
// some performance (b/128001705) and possibly correctness (b/127344411)
|
||||
// issues so we avoid auto-clustering it for non XLA_* devices.
|
||||
bool cluster_svd_op = false;
|
||||
// Whether ops known to be slow or to have correctness issues should be
|
||||
// auto-clustered.
|
||||
bool cluster_slow_and_inaccurate_ops = false;
|
||||
};
|
||||
|
||||
// Registers an XLA backend. `compilation_device_name` is the name of the
|
||||
|
@ -96,7 +96,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -117,7 +117,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:backend",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:dump",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto",
|
||||
@ -126,6 +125,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:source_map_util",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:support",
|
||||
@ -165,11 +165,11 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:backend",
|
||||
"//tensorflow/compiler/xla/service:compile_only_service",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:local_service",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
|
@ -31,7 +31,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/service/compile_only_service.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/local_service.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -39,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
@ -22,12 +22,12 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
|
||||
DeviceMemoryAllocator* allocator) {
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
device_allocator_ = allocator;
|
||||
return *this;
|
||||
}
|
||||
|
||||
DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
|
||||
se::DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
|
||||
return device_allocator_;
|
||||
}
|
||||
|
||||
|
@ -18,11 +18,11 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -57,11 +57,11 @@ class ExecutableBuildOptions {
|
||||
// want to run various algorithms on the device and pick the fastest one -- it
|
||||
// might allocate buffers for use by these algorithms using this allocator.
|
||||
//
|
||||
// This does not need to be the same as the DeviceMemoryAllocator passed when
|
||||
// running the executable.
|
||||
// This does not need to be the same as the se::DeviceMemoryAllocator passed
|
||||
// when running the executable.
|
||||
ExecutableBuildOptions& set_device_allocator(
|
||||
DeviceMemoryAllocator* allocator);
|
||||
DeviceMemoryAllocator* device_allocator() const;
|
||||
se::DeviceMemoryAllocator* allocator);
|
||||
se::DeviceMemoryAllocator* device_allocator() const;
|
||||
|
||||
// Returns a string representation of the build options, suitable for
|
||||
// debugging.
|
||||
@ -77,7 +77,7 @@ class ExecutableBuildOptions {
|
||||
Shape result_layout_;
|
||||
bool result_layout_set_ = false;
|
||||
absl::optional<DebugOptions> debug_options_;
|
||||
DeviceMemoryAllocator* device_allocator_ = nullptr;
|
||||
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
|
||||
int num_replicas_ = 1;
|
||||
};
|
||||
|
||||
|
@ -528,7 +528,9 @@ XlaOp Asin(XlaOp x) {
|
||||
|
||||
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
|
||||
|
||||
XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
|
||||
XlaOp Tan(XlaOp x) {
|
||||
return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
|
||||
}
|
||||
|
||||
// Hyperbolic trigonometric functions.
|
||||
|
||||
@ -574,9 +576,9 @@ XlaOp Acosh(XlaOp x) {
|
||||
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
|
||||
// as 2*x and return log(2) + log(x).
|
||||
//
|
||||
// If x is negative, the above would give us some trouble, because we'd need to
|
||||
// approximate x + sqrt(sqrt(x^2 + 1) - abs(x). But we're saved
|
||||
// by the fact that asinh(-x) = -asinh(x).
|
||||
// If x is negative, the above would give us some trouble; we can't approximate
|
||||
// the result as x + abs(x) = 0! But we're saved by the fact that asinh(-x) =
|
||||
// -asinh(x).
|
||||
XlaOp Asinh(XlaOp x) {
|
||||
XlaBuilder* b = x.builder();
|
||||
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||
@ -636,9 +638,39 @@ XlaOp Atanh(XlaOp x) {
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
|
||||
// Cosh(x) = (e^x + e^-x) / 2
|
||||
// = e^(x + log(1/2)) + e^(-x + log(1/2)).
|
||||
//
|
||||
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
|
||||
// inf.
|
||||
//
|
||||
// This incorrectly overflows to inf for two f32 input values, namely
|
||||
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
|
||||
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
|
||||
// we deem this acceptable.
|
||||
XlaOp Cosh(XlaOp x) {
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||
auto log_one_half = Log(ScalarLike(x, 0.5));
|
||||
return Exp(x + log_one_half) + Exp(-x + log_one_half);
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }
|
||||
// Sinh(x) = (e^x - e^-x) / 2
|
||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)).
|
||||
//
|
||||
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
|
||||
// inf.
|
||||
//
|
||||
// This incorrectly overflows to +/-inf for two f32 input values, namely
|
||||
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
|
||||
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
|
||||
// we deem this acceptable.
|
||||
XlaOp Sinh(XlaOp x) {
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||
auto log_one_half = Log(ScalarLike(x, 0.5));
|
||||
return Exp(x + log_one_half) - Exp(-x + log_one_half);
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
|
@ -279,7 +279,7 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
|
||||
const LiteralSlice& literal, int device_ordinal,
|
||||
DeviceMemoryAllocator* allocator) {
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
if (allocator == nullptr) {
|
||||
allocator = backend().memory_allocator();
|
||||
}
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/local_service.h"
|
||||
@ -32,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -137,7 +137,7 @@ class LocalClient : public Client {
|
||||
// device is used.
|
||||
StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
|
||||
const LiteralSlice& literal, int device_ordinal,
|
||||
DeviceMemoryAllocator* allocator = nullptr);
|
||||
se::DeviceMemoryAllocator* allocator = nullptr);
|
||||
|
||||
// Transfer the BorrowingLiteral to the device with the given ordinal.
|
||||
StatusOr<TransferToServerResponse> TransferToLocalServer(
|
||||
|
@ -26,12 +26,13 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
|
||||
int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; }
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_allocator(
|
||||
DeviceMemoryAllocator* allocator) {
|
||||
stream_executor::DeviceMemoryAllocator* allocator) {
|
||||
allocator_ = allocator;
|
||||
return *this;
|
||||
}
|
||||
|
||||
DeviceMemoryAllocator* ExecutableRunOptions::allocator() const {
|
||||
stream_executor::DeviceMemoryAllocator* ExecutableRunOptions::allocator()
|
||||
const {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
namespace stream_executor {
|
||||
class Stream;
|
||||
class Platform;
|
||||
class DeviceMemoryAllocator;
|
||||
} // namespace stream_executor
|
||||
|
||||
namespace Eigen {
|
||||
@ -31,7 +32,6 @@ struct ThreadPoolDevice;
|
||||
|
||||
namespace xla {
|
||||
|
||||
class DeviceMemoryAllocator;
|
||||
class DeviceAssignment;
|
||||
class ExecutionProfile;
|
||||
|
||||
@ -39,8 +39,9 @@ class ExecutionProfile;
|
||||
class ExecutableRunOptions {
|
||||
public:
|
||||
// Specifies the allocator to use during execution.
|
||||
ExecutableRunOptions& set_allocator(DeviceMemoryAllocator* allocator);
|
||||
DeviceMemoryAllocator* allocator() const;
|
||||
ExecutableRunOptions& set_allocator(
|
||||
stream_executor::DeviceMemoryAllocator* allocator);
|
||||
stream_executor::DeviceMemoryAllocator* allocator() const;
|
||||
|
||||
// If set, this is the device to run the computation on. Valid device_ordinal
|
||||
// values are: 0 to # of devices - 1. These values are identical to the device
|
||||
@ -87,7 +88,7 @@ class ExecutableRunOptions {
|
||||
int rng_seed() const;
|
||||
|
||||
private:
|
||||
DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
int device_ordinal_ = -1;
|
||||
const DeviceAssignment* device_assignment_ = nullptr;
|
||||
stream_executor::Stream* stream_ = nullptr;
|
||||
|
@ -29,6 +29,8 @@ upper_tabs:
|
||||
path: /xla/tiled_layout
|
||||
- title: Using AOT compilation
|
||||
path: /xla/tfcompile
|
||||
- title: Writing custom calls
|
||||
path: /xla/custom_call
|
||||
- heading: Tutorials
|
||||
- title: XLA compile API
|
||||
path: /xla/tutorials/xla_compile
|
||||
|
329
tensorflow/compiler/xla/g3doc/custom_call.md
Normal file
@ -0,0 +1,329 @@
|
||||
# XLA Custom Calls
|
||||
|
||||
This document describes how to write and use XLA "custom calls". Custom calls
|
||||
let you invoke code written in a programming language like C++ or CUDA from an
|
||||
XLA program.
|
||||
|
||||
Warning: Custom calls are a low-level power-user feature. It is easy to break
|
||||
your program in difficult-to-debug (and even difficult-to-notice) ways using
|
||||
custom-calls. You shouldn't use custom calls unless you're prepared to debug XLA
|
||||
yourself when something goes wrong, and you should expect relatively little
|
||||
assistance from XLA developers if you run into trouble.
|
||||
|
||||
Warning: The custom-call API/ABI is not currently stable. We don't intend to
|
||||
change it capriciously, but it may change. Some possible future changes are
|
||||
described below.
|
||||
|
||||
## Custom-call on CPU
|
||||
|
||||
You can create an HLO instruction which represents a custom-call via XLA's
|
||||
client API. This is not exposed via TensorFlow as of this writing.
|
||||
|
||||
For example, the following code uses a custom-call to compute
|
||||
`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! -- do
|
||||
this with regular HLO.)
|
||||
|
||||
```c++
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
|
||||
void do_it() {
|
||||
xla::XlaBuilder b("do_it");
|
||||
xla::XlaOp param0 =
|
||||
xla::Parameter(0, xla::ShapeUtil::CreateShape(F32, {128}), "p0");
|
||||
xla::XlaOp param1 =
|
||||
xla::Parameter(1, xla::ShapeUtil::CreateShape(F32, {2048}), "p1");
|
||||
xla::XlaOp custom_call =
|
||||
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
|
||||
/*output_shape=*/ShapeUtil::MakeShape(F32, {2048}));
|
||||
}
|
||||
|
||||
void do_custom_call(void* out, const void** in) {
|
||||
float* out_buf = reinterpret_cast<float*>(out);
|
||||
const float* in0 = reinterpret_cast<const float*>(in[0]);
|
||||
const float* in1 = reinterpret_cast<const float*>(in[1]);
|
||||
for (int i = 0; i < 2048; ++i) {
|
||||
out_buf[i] = in0[i % 128] + in1[i];
|
||||
}
|
||||
}
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
|
||||
```
|
||||
|
||||
Notice that the function `do_custom_call` needs to know the dimensions of the
|
||||
buffers it operates over. In this example we hardcode the sizes 128 and 2048. If
|
||||
you don't want to do this, you can pass the dimensions in as parameters to the
|
||||
call.
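
For illustration only, here is a minimal sketch of that variant. It assumes a hypothetical registration name `do_custom_call_with_sizes` and an extra S64 operand carrying the two lengths; neither is part of XLA's API, it is just one way you might wire this up.

```c++
// Inside the builder function (cf. do_it above): pass the sizes as a third
// operand instead of hardcoding them in the callback.
xla::XlaOp sizes = xla::ConstantR1<int64>(&b, {128, 2048});
xla::XlaOp custom_call = xla::CustomCall(
    &b, "do_custom_call_with_sizes", /*operands=*/{param0, param1, sizes},
    /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));

// Host-side callback: in[2] now points to the two int64 lengths.
void do_custom_call_with_sizes(void* out, const void** in) {
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  const int64* sizes = reinterpret_cast<const int64*>(in[2]);
  float* out_buf = reinterpret_cast<float*>(out);
  for (int64 i = 0; i < sizes[1]; ++i) {
    out_buf[i] = in0[i % sizes[0]] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call_with_sizes, "Host");
```

On CPU the operand buffers are host-accessible, which is why the callback can simply read the sizes; on GPU that is not true, as discussed in the next section.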
|
||||
|
||||
## Custom-call on GPU
|
||||
|
||||
The GPU custom call framework is somewhat different than that on the CPU. Here
|
||||
is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as
|
||||
the CPU code above.
|
||||
|
||||
```c++
|
||||
void do_it() { /* same implementation as above */ }
|
||||
|
||||
__global__ void custom_call_kernel(const float* in0, const float* in1, float* out) {
|
||||
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
out[idx] = in0[idx % 128] + in1[idx];
|
||||
}
|
||||
|
||||
void do_custom_call(CUstream stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
|
||||
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
|
||||
float* out = reinterpret_cast<float*>(buffers[2]);
|
||||
|
||||
const int64 block_dim = 64;
|
||||
const int64 grid_dim = 2048 / block_dim;
|
||||
custom_call_kernel<<<grid_dim, block_dim,
|
||||
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
|
||||
}
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
|
||||
```
|
||||
|
||||
Notice first that the GPU custom call function *is still a function executed on
|
||||
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
|
||||
on the GPU. Here it launches a CUDA kernel, but it could also do something else,
|
||||
like call cublas.
|
||||
|
||||
`buffers` is an array of pointers which lives on the host, and each element it
|
||||
contains points to device (i.e. GPU) memory. The parameters come first, followed
|
||||
by the output value. This is notably different from the CPU calling convention,
|
||||
which has two params, `ins` and `out`. The main reason we diverge is to make it
|
||||
possible to handle tuple-shaped inputs/outputs efficiently; see the section
|
||||
below.
|
||||
|
||||
As in the CPU example, we've hardcoded the input and output buffer sizes into
|
||||
our custom call. However unlike in the CPU case, passing the buffer sizes in as
|
||||
operands to the custom call would not work well. Usually we need the buffer
|
||||
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
|
||||
the block/grid dimensions to use. But if we were to pass the buffer sizes as
|
||||
operands to our custom call, their values would live in GPU memory. We'd then
|
||||
have to do an expensive synchronous device-to-host memcpy at the start of our
|
||||
operation just to read the sizes.
|
||||
|
||||
To let you work around this, we provide the `opaque` parameter. You can set this
|
||||
to an arbitrary string of bytes when you create the custom call:
|
||||
|
||||
```c++
|
||||
std::string opaque = "...";
|
||||
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
|
||||
/*output_shape=*/ShapeUtil::MakeShape(F32, {2048}),
|
||||
opaque);
|
||||
```
|
||||
|
||||
Since `xla::Shape` has a protocol buffer representation, you could store this
|
||||
serialized proto inside of `opaque` and deserialize it within your GPU
|
||||
custom-call. Note however that although `xla::ShapeProto` does not change
|
||||
frequently, it *does* change. Check the git log to see how it has changed in the
|
||||
past.
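
As a hedged sketch of that idea (the only APIs assumed here are `xla::Shape::ToProto()`, the usual protobuf serialization methods, and the `xla::Shape(const ShapeProto&)` constructor):

```c++
// When building the op: serialize the output shape into `opaque`.
xla::Shape out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2048});
std::string opaque = out_shape.ToProto().SerializeAsString();
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/out_shape, opaque);

// In the GPU callback: recover the shape from `opaque`.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(opaque, opaque_len)) {
    return;  // Real code should surface an error instead of returning silently.
  }
  xla::Shape shape(shape_proto);
  const int64 n = shape.dimensions(0);  // 2048 in this example.
  // ... use `n` to choose block/grid dimensions and launch the kernel ...
}
```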
|
||||
|
||||
## Passing tuples to custom-calls
|
||||
|
||||
Consider the following custom-call.
|
||||
|
||||
```c++
|
||||
using xla::ShapeUtil;
|
||||
Shape p0_shape = ShapeUtil::MakeTuple({
|
||||
ShapeUtil::MakeShape(F32, {32}),
|
||||
ShapeUtil::MakeTuple({
|
||||
ShapeUtil::MakeShape(F32, {64}),
|
||||
ShapeUtil::MakeShape(F32, {128}),
|
||||
}),
|
||||
ShapeUtil::MakeShape(F32, {256}),
|
||||
});
|
||||
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
|
||||
|
||||
Shape out_shape = ShapeUtil::MakeTuple({
|
||||
ShapeUtil::MakeShape(F32, {512}),
|
||||
ShapeUtil::MakeShape(F32, {1024}),
|
||||
});
|
||||
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
|
||||
```
|
||||
|
||||
On both CPU and GPU, a tuple is represented in memory as an array of pointers.
|
||||
In C++-pseudocode, parameter 0 above is laid out as follows.
|
||||
|
||||
```c++
|
||||
// In-memory layout of parameter 0 from custom-call above. True on both CPU
|
||||
// and GPU.
|
||||
float* subbuf0 = new float[32];
|
||||
float* subbuf1 = new float[64];
|
||||
float* subbuf2 = new float[128];
|
||||
float* subbuf3 = new float[256];
|
||||
|
||||
void* subtuple = new void*[2];
|
||||
(*subtuple)[0] = subbuf1;
|
||||
(*subtuple)[1] = subbuf2;
|
||||
|
||||
void* p0 = new void*[3];
|
||||
(*p0)[0] = subbuf0;
|
||||
(*p0)[1] = subtuple;
|
||||
(*p0)[2] = subbuf3;
|
||||
```
|
||||
|
||||
Although the in-memory representation of tuples is the same in CPU and GPU, they
|
||||
are handled differently in the CPU and GPU custom-call calling conventions.
|
||||
|
||||
### Tuple outputs as temp buffers
|
||||
|
||||
Tuple inputs to custom-calls are a convenience, but they aren't strictly
|
||||
necessary. If we didn't support tuple inputs to custom calls, you could always
|
||||
unpack the tuples using get-tuple-element before passing them to the custom
|
||||
call.
|
||||
|
||||
On the other hand, tuple *outputs* do let you do things you couldn't otherwise.
|
||||
|
||||
The obvious reason to have tuple outputs is, that's how a custom call (or any
|
||||
other XLA op) returns multiple independent arrays.
|
||||
|
||||
But less obviously, a tuple output is also a way to give your custom call temp
|
||||
memory. Yes, an *output* can represent a temp buffer. Consider, an output buffer
|
||||
has the property that the op can write to it, and it can read from it after it's
|
||||
been written to. That's exactly what you want from a temp buffer.
|
||||
|
||||
In the example above, suppose we wanted to use the `F32[1024]` as a temp buffer.
|
||||
Then we'd write the HLO just as above, and we'd simply never read tuple index 1
|
||||
of the custom call's output.
|
||||
|
||||
### Tuples in CPU custom-calls
|
||||
|
||||
In CPU code, we have a function `do_custom_call(const void** ins, void* out)`.
|
||||
`ins` is an array with just one element, which points to `param0`. The
|
||||
subbuffers of `param0` are accessible by dereferencing that pointer, and the
|
||||
subbuffers of `output_tuple` are accessible by dereferencing `out`.
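
A minimal sketch of that dereferencing, under the in-memory layout shown earlier (all variable names are illustrative):

```c++
void do_custom_call(const void** ins, void* out) {
  // ins[0] is the root tuple of parameter 0: an array of three pointers.
  const void* const* param0 = reinterpret_cast<const void* const*>(ins[0]);
  const float* subbuf0 = reinterpret_cast<const float*>(param0[0]);    // F32[32]
  const void* const* subtuple =
      reinterpret_cast<const void* const*>(param0[1]);
  const float* subbuf1 = reinterpret_cast<const float*>(subtuple[0]);  // F32[64]
  const float* subbuf2 = reinterpret_cast<const float*>(subtuple[1]);  // F32[128]
  const float* subbuf3 = reinterpret_cast<const float*>(param0[2]);    // F32[256]

  // `out` is the output root tuple: an array of two pointers.
  void* const* output_tuple = reinterpret_cast<void* const*>(out);
  float* out0 = reinterpret_cast<float*>(output_tuple[0]);             // F32[512]
  float* out1 = reinterpret_cast<float*>(output_tuple[1]);             // F32[1024]

  // ... read the input subbuffers and write the output subbuffers ...
}
```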
|
||||
|
||||
### Tuples in GPU custom-calls
|
||||
|
||||
In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
|
||||
this case `buffers` is a host array of *nine* device pointers, one for each
|
||||
nested buffer. To generate the flat list, we iterate over the parameters and
|
||||
output, and then do preorder traversal of their shapes. Concretely:
|
||||
|
||||
```c++
|
||||
// Layout of `buffers` parameter to GPU custom call function for custom-call
|
||||
// above.
|
||||
buffers[0] == param0
|
||||
buffers[1] == subbuf0 or null
|
||||
buffers[2] == subtuple or null
|
||||
buffers[3] == subbuf1 or null
|
||||
buffers[4] == subbuf2 or null
|
||||
buffers[5] == subbuf3 or null
|
||||
buffers[6] == output_tuple
|
||||
buffers[7] == output_subbuf0
|
||||
buffers[8] == output_subbuf1
|
||||
```
|
||||
|
||||
The `or null` part is significant. A sub-buffer of an input tuple will be
non-null in the `buffers` list if XLA is able to statically analyze the program
and figure out the address of the sub-buffer. This is usually the case, but it
may not be in programs with control flow and/or `select` ops over tuples.

A correct custom-call implementation that accepts a tuple as input must always
handle null input sub-buffers by dereferencing the root tuple.

The rule is reversed for output buffers. The output sub-buffers will always be
populated, but it's up to the custom call to populate the root tuple at the end.

See the following code. We leave out CUDA error handling for clarity, but you
should include it in real code; otherwise it can be hard to tell when a stream
encounters an error.

```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len) {
  bool needs_sync = false;
  const float* subbuf0 = reinterpret_cast<const float*>(buffers[1]);
  if (subbuf0 == nullptr) {
    needs_sync = true;
    // Read the first pointer out of param0 (the root tuple) on the device.
    cudaMemcpyAsync(&subbuf0, buffers[0], sizeof(void*),
                    cudaMemcpyDeviceToHost, stream);
  }
  const void** subtuple = reinterpret_cast<const void**>(buffers[2]);
  if (subtuple == nullptr) {
    needs_sync = true;
    // Read the second pointer out of param0 (the root tuple) on the device.
    cudaMemcpyAsync(&subtuple,
                    reinterpret_cast<const void* const*>(buffers[0]) + 1, ...);
  }

  // ... similarly for other params ...

  // Wait for copies enqueued above to complete.
  if (needs_sync) {
    cudaStreamSynchronize(stream);
  }
  needs_sync = false;

  // Now that we have `subtuple`, we can get subbuf1 and subbuf2.
  const float* subbuf1 = reinterpret_cast<const float*>(buffers[3]);
  if (subbuf1 == nullptr) {
    needs_sync = true;
    cudaMemcpyAsync(&subbuf1, subtuple, ...);
  }
  const float* subbuf2 = reinterpret_cast<const float*>(buffers[4]);
  if (subbuf2 == nullptr) {
    needs_sync = true;
    cudaMemcpyAsync(&subbuf2, subtuple + 1, ...);
  }

  // Wait for copies enqueued above to complete.
  if (needs_sync) {
    cudaStreamSynchronize(stream);
  }

  // ... actually run the kernel ...

  // Fill the output tuple.
  void* outputs[2] = {buffers[7], buffers[8]};
  cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice,
                  stream);

  // Necessary to force the cudaMemcpyAsync above to complete before `outputs`
  // goes out of scope. A sync is only necessary in the tuple-output case; see
  // below for a way to avoid it.
  cudaStreamSynchronize(stream);
}
```

The `cudaStreamSynchronize` at the end of the function is unfortunate, as it's
not required in the non-tuple-output case, and it can be expensive. One way to
get around this would be to make `outputs` a global variable and ensure that the
previous cudaMemcpyAsync has completed before overwriting the global and
enqueueing another one. This is sketched below.

```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len) {

  // ... Beginning of function is the same as above ...

  // ... actually run the kernel ...

  static std::atomic<bool> first_time{true};
  static CUevent event;
  static void* outputs[2];
  if (first_time.exchange(false)) {
    // First time running this function. Initialize `event`.
    cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
  } else {
    // Not the first time running this function. Wait for the previous event to
    // complete before touching `outputs`.
    cuEventSynchronize(event);
  }

  // Fill the output tuple.
  outputs[0] = buffers[7];
  outputs[1] = buffers[8];
  cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice,
                  stream);

  // Unblock `event` after the memcpy completes.
  cuEventRecord(event, stream);
}
```

This simple implementation would limit parallelism if you want to run this op on
multiple GPUs concurrently (or on one GPU with multiple streams); in that case
you might need multiple events and globals. We have seen one implementation of
this algorithm that keeps a pool of globals and events and periodically polls
them (perhaps on each call to the op) to garbage-collect entries whose copies
have completed.

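Below is a minimal sketch of that pooling idea, assuming the same mixed CUDA
driver/runtime usage as the examples above. The pool size, the `OutputSlot`
struct, and the `AcquireSlot` helper are illustrative names, not from the
original text.

```c++
// Illustrative sketch of a pool of (event, staging buffer) slots, as described
// above. The fixed pool size, `OutputSlot`, and `AcquireSlot` are hypothetical.
#include <array>
#include <mutex>

#include <cuda.h>          // CUevent and the cuEvent* driver APIs.
#include <cuda_runtime.h>  // cudaMemcpyAsync.

namespace {

struct OutputSlot {
  CUevent event = nullptr;  // Signaled when the memcpy using `outputs` is done.
  void* outputs[2];         // Host staging area for the output root tuple.
  bool in_use = false;
};

std::mutex pool_mu;
std::array<OutputSlot, 16> pool;  // Illustrative fixed-size pool.

// Returns a slot whose previous memcpy (if any) has completed. Finished slots
// are reclaimed (garbage collected) as a side effect of the scan.
OutputSlot* AcquireSlot() {
  std::lock_guard<std::mutex> lock(pool_mu);
  for (OutputSlot& slot : pool) {
    if (slot.event == nullptr) {
      cuEventCreate(&slot.event, CU_EVENT_DISABLE_TIMING);
      slot.in_use = true;
      return &slot;
    }
    if (slot.in_use && cuEventQuery(slot.event) == CUDA_SUCCESS) {
      slot.in_use = false;  // Previous copy finished; the slot is free again.
    }
    if (!slot.in_use) {
      slot.in_use = true;
      return &slot;
    }
  }
  // Pool exhausted: block on one slot rather than growing the pool.
  cuEventSynchronize(pool[0].event);
  return &pool[0];
}

}  // namespace

// Usage inside do_custom_call, after the kernel has been enqueued:
//   OutputSlot* slot = AcquireSlot();
//   slot->outputs[0] = buffers[7];
//   slot->outputs[1] = buffers[8];
//   cudaMemcpyAsync(buffers[6], slot->outputs, sizeof(slot->outputs),
//                   cudaMemcpyHostToDevice, stream);
//   cuEventRecord(slot->event, stream);
```
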
@ -67,8 +67,8 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@pybind11",
|
||||
@ -109,9 +109,9 @@ cc_library(
|
||||
hdrs = ["shared_device_buffer.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
@ -131,11 +131,50 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_client",
|
||||
srcs = ["local_client.cc"],
|
||||
hdrs = ["local_client.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
"-Wno-c++98-c++11-compat",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":shared_device_buffer",
|
||||
":types",
|
||||
":worker_thread",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_pybind_extension(
|
||||
name = "xla_extension",
|
||||
srcs = [
|
||||
"local_client.cc",
|
||||
"local_client.h",
|
||||
"xla.cc",
|
||||
],
|
||||
copts = [
|
||||
@ -146,22 +185,19 @@ tf_pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "xla_extension",
|
||||
deps = [
|
||||
":local_client",
|
||||
":shared_device_buffer",
|
||||
":types",
|
||||
":worker_thread",
|
||||
":xrt",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -178,7 +214,7 @@ tf_pybind_extension(
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||
@ -186,9 +222,7 @@ tf_pybind_extension(
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
# Do NOT remove this dependency. The XLA Python extension must not
|
||||
# depend on any part of TensorFlow at runtime, **including**
|
||||
# libtensorflow_framework.so. The XLA module is deployed self-contained
|
||||
|
@ -78,7 +78,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
||||
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
||||
}
|
||||
cpu::CustomCallTargetRegistry::Global()->Register(
|
||||
std::string(fn_name.begin(), fn_name.end()), static_cast<void*>(capsule));
|
||||
CustomCallTargetRegistry::Global()->Register(
|
||||
fn_name, static_cast<void*>(capsule), "Host");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -147,7 +147,12 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
|
||||
"py_xla_execute");
|
||||
}
|
||||
|
||||
Device::~Device() { compute_stream_->parent()->SynchronizeAllActivity(); }
|
||||
Device::~Device() {
|
||||
bool ok = compute_stream_->parent()->SynchronizeAllActivity();
|
||||
if (!ok) {
|
||||
LOG(ERROR) << "SynchronizeAllActivity failed when destroying Device.";
|
||||
}
|
||||
}
|
||||
|
||||
void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
|
||||
std::function<void()> callback) const {
|
||||
@ -155,7 +160,7 @@ void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
|
||||
[this, callback]() { worker_thread_->Schedule(std::move(callback)); });
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
|
||||
StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
|
||||
const std::string& platform_name, const std::string& xla_platform_name,
|
||||
bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
@ -168,7 +173,7 @@ StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
|
||||
options.set_platform(platform);
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
return absl::make_unique<PyLocalClient>(platform_name, client, asynchronous);
|
||||
return std::make_shared<PyLocalClient>(platform_name, client, asynchronous);
|
||||
}
|
||||
|
||||
PyLocalClient::PyLocalClient(std::string platform_name, LocalClient* client,
|
||||
@ -210,9 +215,9 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
||||
}
|
||||
|
||||
static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
|
||||
const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
|
||||
const Device& device) {
|
||||
DeviceMemoryAllocator* allocator =
|
||||
const PythonBufferTree& tree, int device_ordinal,
|
||||
std::shared_ptr<PyLocalClient> client, const Device& device) {
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
client->client()->backend().memory_allocator();
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
@ -255,13 +260,13 @@ static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
|
||||
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
|
||||
device_buffer);
|
||||
}
|
||||
return PyLocalBuffer(shape, std::move(device_buffer), client);
|
||||
return PyLocalBuffer(shape, std::move(device_buffer), std::move(client));
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
|
||||
PyLocalClient* client,
|
||||
int device_ordinal) {
|
||||
StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(
|
||||
const py::object& argument, std::shared_ptr<PyLocalClient> client,
|
||||
int device_ordinal) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
|
||||
|
||||
@ -277,13 +282,13 @@ StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
|
||||
<< " device ordinal: " << device_ordinal;
|
||||
|
||||
const Device& device = client->device(device_ordinal);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
PyLocalBuffer buffer,
|
||||
TransferHostToDeviceAsync(tree, device_ordinal, client, device));
|
||||
TF_ASSIGN_OR_RETURN(PyLocalBuffer buffer,
|
||||
TransferHostToDeviceAsync(tree, device_ordinal,
|
||||
std::move(client), device));
|
||||
|
||||
device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref));
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
@ -291,7 +296,7 @@ StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
|
||||
/*static */ StatusOr<std::vector<PyLocalBuffer>>
|
||||
PyLocalBuffer::FromPythonValues(
|
||||
const std::vector<std::pair<py::object, int>>& arguments,
|
||||
PyLocalClient* client) {
|
||||
std::shared_ptr<PyLocalClient> client) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues");
|
||||
int num_arguments = static_cast<int>(arguments.size());
|
||||
std::vector<PyLocalBuffer> outputs(num_arguments);
|
||||
@ -344,7 +349,7 @@ PyLocalBuffer::FromPythonValues(
|
||||
device.ThenRelease(device.host_to_device_stream(),
|
||||
std::move(transfers[i].py_buffer_ref));
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
|
||||
}
|
||||
}
|
||||
|
||||
@ -355,8 +360,8 @@ PyLocalBuffer::FromPythonValues(
|
||||
}
|
||||
|
||||
/* static */ StatusOr<PyLocalBuffer> PyLocalBuffer::MakeTuple(
|
||||
const std::vector<PyLocalBuffer> buffers, PyLocalClient* client,
|
||||
int device_ordinal) {
|
||||
const std::vector<PyLocalBuffer> buffers,
|
||||
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
||||
std::vector<xla::Shape> host_shapes;
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> device_buffers;
|
||||
host_shapes.reserve(buffers.size());
|
||||
@ -367,7 +372,7 @@ PyLocalBuffer::FromPythonValues(
|
||||
host_shapes.push_back(buffer.on_host_shape());
|
||||
device_buffers.push_back(buffer.device_buffer());
|
||||
}
|
||||
DeviceMemoryAllocator* allocator =
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
client->client()->backend().memory_allocator();
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
@ -382,7 +387,7 @@ PyLocalBuffer::FromPythonValues(
|
||||
device_buffers, transfer_manager, allocator,
|
||||
device_ordinal, definition_event));
|
||||
PyLocalBuffer buffer(ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer,
|
||||
client);
|
||||
std::move(client));
|
||||
|
||||
// TODO(phawkins): extend TransferManager so we do not need to form a full
|
||||
// ShapedBuffer just to write the root tuple index table.
|
||||
@ -393,8 +398,8 @@ PyLocalBuffer::FromPythonValues(
|
||||
// Wait for the compute stream so that memory allocations are synchronized.
|
||||
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
|
||||
}
|
||||
transfer_manager->WriteRootTupleIndexTable(device.host_to_device_stream(),
|
||||
shaped_buffer);
|
||||
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
|
||||
device.host_to_device_stream(), shaped_buffer));
|
||||
if (definition_event) {
|
||||
definition_event->RecordOnStream(device.host_to_device_stream());
|
||||
}
|
||||
@ -404,7 +409,7 @@ PyLocalBuffer::FromPythonValues(
|
||||
std::move(tuple_buffer));
|
||||
}
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(device.host_to_device_stream()->BlockHostUntilDone());
|
||||
}
|
||||
|
||||
return buffer;
|
||||
@ -412,10 +417,10 @@ PyLocalBuffer::FromPythonValues(
|
||||
|
||||
PyLocalBuffer::PyLocalBuffer(
|
||||
Shape on_host_shape, std::shared_ptr<PySharedDeviceBuffer> device_buffer,
|
||||
PyLocalClient* client)
|
||||
: on_host_shape_(std::move(on_host_shape)),
|
||||
device_buffer_(std::move(device_buffer)),
|
||||
client_(client) {}
|
||||
std::shared_ptr<PyLocalClient> client)
|
||||
: client_(std::move(client)),
|
||||
on_host_shape_(std::move(on_host_shape)),
|
||||
device_buffer_(std::move(device_buffer)) {}
|
||||
|
||||
StatusOr<py::object> PyLocalBuffer::ToPython() const {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
|
||||
@ -462,10 +467,10 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalBuffer::DestructureTuple() {
|
||||
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client)
|
||||
: executable_(std::move(executable)),
|
||||
device_assignment_(std::move(device_assignment)),
|
||||
client_(client) {}
|
||||
DeviceAssignment device_assignment, std::shared_ptr<PyLocalClient> client)
|
||||
: client_(std::move(client)),
|
||||
executable_(std::move(executable)),
|
||||
device_assignment_(std::move(device_assignment)) {}
|
||||
|
||||
std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
|
||||
int num_replicas = device_assignment_.replica_count();
|
||||
@ -543,7 +548,7 @@ StatusOr<PyLocalBuffer> PyLocalExecutable::ExecuteHelper(
|
||||
device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_);
|
||||
}
|
||||
if (!device.asynchronous()) {
|
||||
device.compute_stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(device.compute_stream()->BlockHostUntilDone());
|
||||
}
|
||||
return PyLocalBuffer(on_host_shape, std::move(out_buffer), client_);
|
||||
}
|
||||
@ -652,7 +657,7 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
|
||||
PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options,
|
||||
PyLocalClient* client) {
|
||||
std::shared_ptr<PyLocalClient> client) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
|
||||
std::vector<const Shape*> argument_layout_pointers;
|
||||
argument_layout_pointers.reserve(argument_layouts.size());
|
||||
@ -705,7 +710,7 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
|
||||
std::move(device_assignment), client);
|
||||
std::move(device_assignment), std::move(client));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -169,12 +169,13 @@ class PyLocalClient {
|
||||
public:
|
||||
// Initializes a local XLA client for `platform_name`. Returns an error if no
|
||||
// such platform exists, or if the platform has no visible devices.
|
||||
static StatusOr<std::unique_ptr<PyLocalClient>> Get(
|
||||
const std::string& platform_name, const std::string& xla_platform_id,
|
||||
static StatusOr<std::shared_ptr<PyLocalClient>> Get(
|
||||
const std::string& platform_name, const std::string& xla_platform_name,
|
||||
bool asynchronous);
|
||||
|
||||
explicit PyLocalClient(std::string platform_name, LocalClient* client,
|
||||
bool asynchronous);
|
||||
virtual ~PyLocalClient() = default;
|
||||
|
||||
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
|
||||
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
|
||||
@ -192,7 +193,7 @@ class PyLocalClient {
|
||||
|
||||
PythonRefManager& py_ref_manager() { return py_ref_manager_; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
std::string platform_name_;
|
||||
LocalClient* client_;
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
@ -205,29 +206,30 @@ class PyLocalClient {
|
||||
// Holds a reference from Python to one or more device buffers.
|
||||
class PyLocalBuffer {
|
||||
public:
|
||||
static StatusOr<PyLocalBuffer> FromPython(const pybind11::object& argument,
|
||||
PyLocalClient* client,
|
||||
int device_ordinal);
|
||||
static StatusOr<PyLocalBuffer> FromPython(
|
||||
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
|
||||
int device_ordinal);
|
||||
|
||||
// Converts multiple (python object, device ordinal) pairs into
|
||||
// PyLocalBuffers in parallel.
|
||||
static StatusOr<std::vector<PyLocalBuffer>> FromPythonValues(
|
||||
const std::vector<std::pair<pybind11::object, int>>& argument,
|
||||
PyLocalClient* client);
|
||||
std::shared_ptr<PyLocalClient> client);
|
||||
|
||||
static StatusOr<PyLocalBuffer> MakeTuple(
|
||||
const std::vector<PyLocalBuffer> buffers, PyLocalClient* client,
|
||||
int device_ordinal);
|
||||
const std::vector<PyLocalBuffer> buffers,
|
||||
std::shared_ptr<PyLocalClient> client, int device_ordinal);
|
||||
|
||||
PyLocalBuffer() = default;
|
||||
PyLocalBuffer(Shape on_host_shape,
|
||||
std::shared_ptr<PySharedDeviceBuffer> device_buffer,
|
||||
PyLocalClient* client);
|
||||
std::shared_ptr<PyLocalClient> client);
|
||||
StatusOr<pybind11::object> ToPython() const;
|
||||
const Shape& on_host_shape() const { return on_host_shape_; }
|
||||
const std::shared_ptr<PySharedDeviceBuffer>& device_buffer() const {
|
||||
return device_buffer_;
|
||||
}
|
||||
int device_ordinal() const { return device_buffer_->device_ordinal(); }
|
||||
|
||||
void Delete() {
|
||||
device_buffer_ = nullptr;
|
||||
@ -242,9 +244,9 @@ class PyLocalBuffer {
|
||||
StatusOr<std::vector<PyLocalBuffer>> DestructureTuple();
|
||||
|
||||
private:
|
||||
std::shared_ptr<PyLocalClient> client_ = nullptr;
|
||||
Shape on_host_shape_;
|
||||
std::shared_ptr<PySharedDeviceBuffer> device_buffer_;
|
||||
PyLocalClient* client_ = nullptr;
|
||||
};
|
||||
|
||||
// Represents a compiled computation that can be executed given handles to
|
||||
@ -254,10 +256,12 @@ class PyLocalExecutable {
|
||||
// Compiles a computation to an executable.
|
||||
static StatusOr<std::unique_ptr<PyLocalExecutable>> Compile(
|
||||
const XlaComputation& computation, std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options, PyLocalClient* client);
|
||||
const ExecutableBuildOptions* build_options,
|
||||
std::shared_ptr<PyLocalClient> client);
|
||||
|
||||
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client);
|
||||
DeviceAssignment device_assignment,
|
||||
std::shared_ptr<PyLocalClient> client);
|
||||
|
||||
int num_replicas() const {
|
||||
return executable_->build_options().num_replicas();
|
||||
@ -285,9 +289,9 @@ class PyLocalExecutable {
|
||||
StatusOr<PyLocalBuffer> ExecuteHelper(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles, int replica);
|
||||
|
||||
std::shared_ptr<PyLocalClient> const client_;
|
||||
std::shared_ptr<LocalExecutable> executable_;
|
||||
const DeviceAssignment device_assignment_;
|
||||
PyLocalClient* const client_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -47,14 +47,14 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) {
|
||||
static std::shared_ptr<PySharedDeviceBuffer>
|
||||
BufferFromScopedShapedBufferIterator(
|
||||
const Shape& on_device_shape, int device_ordinal,
|
||||
DeviceMemoryAllocator* allocator,
|
||||
se::DeviceMemoryAllocator* allocator,
|
||||
ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
|
||||
const ShapeTree<se::DeviceMemoryBase>::iterator& end,
|
||||
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
|
||||
CHECK(*iterator != end);
|
||||
|
||||
OwningDeviceMemory device_memory((*iterator)->second, device_ordinal,
|
||||
allocator);
|
||||
se::OwningDeviceMemory device_memory((*iterator)->second, device_ordinal,
|
||||
allocator);
|
||||
(*iterator)->second = se::DeviceMemoryBase();
|
||||
++*iterator;
|
||||
|
||||
@ -90,7 +90,7 @@ PySharedDeviceBuffer::FromScopedShapedBuffer(
|
||||
/* static */ StatusOr<std::shared_ptr<PySharedDeviceBuffer>>
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
|
||||
TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event) {
|
||||
std::vector<Shape> child_shapes;
|
||||
@ -102,7 +102,7 @@ PySharedDeviceBuffer::MakeTuple(
|
||||
|
||||
Shape shape = ShapeUtil::MakeTupleShape(child_shapes);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
OwningDeviceMemory device_memory,
|
||||
se::OwningDeviceMemory device_memory,
|
||||
allocator->Allocate(device_ordinal,
|
||||
transfer_manager->GetByteSizeRequirement(shape)));
|
||||
return std::make_shared<PySharedDeviceBuffer>(
|
||||
@ -113,10 +113,10 @@ PySharedDeviceBuffer::MakeTuple(
|
||||
/* static */ StatusOr<std::shared_ptr<PySharedDeviceBuffer>>
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
Shape on_device_shape, TransferManager* transfer_manager,
|
||||
DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
se::DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
OwningDeviceMemory device_memory,
|
||||
se::OwningDeviceMemory device_memory,
|
||||
allocator->Allocate(
|
||||
device_ordinal,
|
||||
transfer_manager->GetByteSizeRequirement(on_device_shape)));
|
||||
@ -153,7 +153,7 @@ ShapedBuffer PySharedDeviceBuffer::AsShapedBuffer(
|
||||
}
|
||||
|
||||
PySharedDeviceBuffer::PySharedDeviceBuffer(
|
||||
Shape on_device_shape, OwningDeviceMemory device_memory,
|
||||
Shape on_device_shape, se::OwningDeviceMemory device_memory,
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event)
|
||||
: on_device_shape_(std::move(on_device_shape)),
|
||||
|
@ -17,11 +17,11 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
#include "tensorflow/stream_executor/owning_device_memory.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -93,14 +93,14 @@ class PySharedDeviceBuffer {
|
||||
// Makes a tuple buffer. Does not initialize the tuple table.
|
||||
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeTuple(
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
|
||||
TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
// Makes an uninitialized array buffer.
|
||||
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeArray(
|
||||
Shape on_device_shape, TransferManager* transfer_manager,
|
||||
DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
se::DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
// Builds a ShapedBuffer view onto the buffers of 'tree'. Since
|
||||
@ -113,7 +113,7 @@ class PySharedDeviceBuffer {
|
||||
const std::vector<std::shared_ptr<PySharedDeviceBuffer>>& children() const {
|
||||
return children_;
|
||||
}
|
||||
const OwningDeviceMemory& device_memory() const { return device_memory_; }
|
||||
const se::OwningDeviceMemory& device_memory() const { return device_memory_; }
|
||||
int device_ordinal() const { return device_memory_.device_ordinal(); }
|
||||
const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
|
||||
return definition_event_;
|
||||
@ -121,7 +121,7 @@ class PySharedDeviceBuffer {
|
||||
|
||||
PySharedDeviceBuffer() = default;
|
||||
PySharedDeviceBuffer(
|
||||
Shape on_device_shape, OwningDeviceMemory device_memory,
|
||||
Shape on_device_shape, se::OwningDeviceMemory device_memory,
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
@ -130,7 +130,7 @@ class PySharedDeviceBuffer {
|
||||
// one-to-one with the tree of device buffers, so to avoid representational
|
||||
// awkwardness we maintain on-host shapes separately.
|
||||
Shape on_device_shape_;
|
||||
OwningDeviceMemory device_memory_;
|
||||
se::OwningDeviceMemory device_memory_;
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children_;
|
||||
|
||||
// An event that is triggered when the content of one or more buffers is
|
||||
|
@ -16,15 +16,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/stream_executor/owning_device_memory.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
xla::StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
|
||||
const py::dtype& np_type) {
|
||||
xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
|
||||
{{'b', 1}, PRED},
|
||||
@ -50,6 +49,42 @@ xla::StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
|
||||
return it->second;
|
||||
}
|
||||
|
||||
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
|
||||
switch (type) {
|
||||
case PRED:
|
||||
return py::dtype::of<bool>();
|
||||
case S8:
|
||||
return py::dtype::of<int8>();
|
||||
case S16:
|
||||
return py::dtype::of<int16>();
|
||||
case S32:
|
||||
return py::dtype::of<int32>();
|
||||
case S64:
|
||||
return py::dtype::of<int64>();
|
||||
case U8:
|
||||
return py::dtype::of<uint8>();
|
||||
case U16:
|
||||
return py::dtype::of<uint16>();
|
||||
case U32:
|
||||
return py::dtype::of<uint32>();
|
||||
case U64:
|
||||
return py::dtype::of<uint64>();
|
||||
case F16:
|
||||
return py::dtype("e");
|
||||
case F32:
|
||||
return py::dtype::of<float>();
|
||||
case F64:
|
||||
return py::dtype::of<double>();
|
||||
case C64:
|
||||
return py::dtype::of<std::complex<float>>();
|
||||
case C128:
|
||||
return py::dtype::of<std::complex<double>>();
|
||||
default:
|
||||
return Unimplemented("Unimplemented primitive type %s",
|
||||
PrimitiveType_Name(type));
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a numpy-style format descriptor string for `type`.
|
||||
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
|
||||
switch (type) {
|
||||
@ -159,4 +194,20 @@ StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
|
||||
return tree;
|
||||
}
|
||||
|
||||
py::tuple IntSpanToTuple(absl::Span<int64 const> xs) {
|
||||
py::tuple out(xs.size());
|
||||
for (int i = 0; i < xs.size(); ++i) {
|
||||
out[i] = py::int_(xs[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<int64> IntSequenceToVector(const py::object& sequence) {
|
||||
std::vector<int64> output;
|
||||
for (auto item : sequence) {
|
||||
output.push_back(item.cast<int64>());
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -32,29 +32,48 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Converts a pybind11-style NumPy dtype to a PrimitiveType.
|
||||
StatusOr<PrimitiveType> NumpyTypeToPrimitiveType(
|
||||
const pybind11::dtype& np_type);
|
||||
// Helper that converts a failing StatusOr to an exception.
|
||||
// For use only inside pybind11 code.
|
||||
template <typename T>
|
||||
T ValueOrThrow(StatusOr<T> v) {
|
||||
if (!v.ok()) {
|
||||
throw std::runtime_error(v.status().ToString());
|
||||
}
|
||||
return v.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Converts a NumPy dtype to a PrimitiveType.
|
||||
StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);
|
||||
|
||||
// Converts a PrimitiveType to a Numpy dtype.
|
||||
StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
|
||||
|
||||
// Converts a literal to (possibly-nested tuples of) NumPy arrays.
|
||||
// The literal's leaf arrays are not copied; instead the NumPy arrays share
|
||||
// buffers with the literals. Takes ownership of `literal` and keeps the
|
||||
// necessary pieces alive using Python reference counting.
|
||||
// Requires the GIL.
|
||||
StatusOr<pybind11::object> LiteralToPython(
|
||||
std::unique_ptr<xla::Literal> literal);
|
||||
StatusOr<pybind11::object> LiteralToPython(std::unique_ptr<Literal> literal);
|
||||
|
||||
// Converts a Python object into an XLA shape and a vector of leaf buffers.
|
||||
// The leaf buffers correspond to a depth-first, left-to-right traversal of
|
||||
// the Python value.
|
||||
// Requires the GIL.
|
||||
struct PythonBufferTree {
|
||||
absl::InlinedVector<xla::BorrowingLiteral, 1> leaves;
|
||||
xla::Shape shape;
|
||||
absl::InlinedVector<BorrowingLiteral, 1> leaves;
|
||||
Shape shape;
|
||||
};
|
||||
StatusOr<PythonBufferTree> GetPythonBufferTree(
|
||||
const pybind11::object& argument);
|
||||
|
||||
// Converts a sequence of int64s to a Python tuple of ints.
|
||||
// Pybind11 by default converts a std::vector<int64> to a Python list; for
|
||||
// shapes we frequently want a tuple instead.
|
||||
pybind11::tuple IntSpanToTuple(absl::Span<int64 const> xs);
|
||||
|
||||
// Converts a Python sequence of integers to a std::vector<int64>
|
||||
std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
// This namespace is a documented pybind11 extension point.
|
||||
@ -161,7 +180,7 @@ struct type_caster<xla::BorrowingLiteral> {
|
||||
for (int i = 0; i < array.ndim(); ++i) {
|
||||
dims[i] = array.shape(i);
|
||||
}
|
||||
auto type = xla::NumpyTypeToPrimitiveType(array.dtype());
|
||||
auto type = xla::DtypeToPrimitiveType(array.dtype());
|
||||
if (!type.ok()) {
|
||||
throw std::runtime_error(type.status().ToString());
|
||||
}
|
||||
|
@ -16,10 +16,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/hash/hash.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
@ -33,7 +34,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/python/xrt.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
@ -129,34 +129,95 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.value("TOKEN", TOKEN);
|
||||
|
||||
// Shapes
|
||||
py::class_<Shape>(m, "Shape")
|
||||
py::class_<Shape> shape_class(m, "Shape");
|
||||
shape_class
|
||||
.def_static(
|
||||
"Tuple",
|
||||
"tuple_shape",
|
||||
[](std::vector<Shape> shapes) -> Shape {
|
||||
return ShapeUtil::MakeTupleShape(shapes);
|
||||
},
|
||||
"Makes a tuple shape.")
|
||||
"Constructs a tuple shape.")
|
||||
.def_static(
|
||||
"Array",
|
||||
[](PrimitiveType type, std::vector<int64> dims,
|
||||
absl::optional<std::vector<int64>> layout) -> Shape {
|
||||
if (layout) {
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, *layout);
|
||||
"array_shape",
|
||||
[](PrimitiveType type, py::object dims_seq,
|
||||
absl::optional<py::object> layout_seq) -> Shape {
|
||||
std::vector<int64> dims = IntSequenceToVector(dims_seq);
|
||||
if (layout_seq) {
|
||||
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
|
||||
} else {
|
||||
Shape shape = ShapeUtil::MakeShape(type, dims);
|
||||
shape.clear_layout();
|
||||
return shape;
|
||||
}
|
||||
},
|
||||
"Makes an array shape.", py::arg("type"), py::arg("dims"),
|
||||
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
|
||||
py::arg("layout") = absl::nullopt)
|
||||
.def_static(
|
||||
"array_shape",
|
||||
[](py::dtype dtype, py::object dims_seq,
|
||||
absl::optional<py::object> layout_seq) -> Shape {
|
||||
PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
|
||||
std::vector<int64> dims = IntSequenceToVector(dims_seq);
|
||||
if (layout_seq) {
|
||||
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
|
||||
} else {
|
||||
Shape shape = ShapeUtil::MakeShape(type, dims);
|
||||
shape.clear_layout();
|
||||
return shape;
|
||||
}
|
||||
},
|
||||
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
|
||||
py::arg("layout") = absl::nullopt)
|
||||
.def("dimensions",
|
||||
static_cast<const std::vector<int64>& (Shape::*)() const>(
|
||||
&Shape::dimensions))
|
||||
.def("element_type", &Shape::element_type)
|
||||
[](const Shape& shape) -> py::tuple {
|
||||
return IntSpanToTuple(shape.dimensions());
|
||||
})
|
||||
.def("xla_element_type", &Shape::element_type)
|
||||
.def("element_type",
|
||||
[](const Shape& shape) {
|
||||
return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
|
||||
})
|
||||
.def("numpy_dtype",
|
||||
[](const Shape& shape) {
|
||||
if (shape.IsTuple()) {
|
||||
return py::dtype("O");
|
||||
}
|
||||
return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
|
||||
})
|
||||
.def("is_tuple", &Shape::IsTuple)
|
||||
.def("is_array", &Shape::IsArray)
|
||||
.def("rank", &Shape::rank)
|
||||
.def("to_serialized_proto",
|
||||
[](const Shape& shape) {
|
||||
ShapeProto proto = shape.ToProto();
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
})
|
||||
.def("tuple_shapes",
|
||||
static_cast<const std::vector<Shape>& (Shape::*)() const>(
|
||||
&Shape::tuple_shapes))
|
||||
[](const Shape& shape) {
|
||||
return std::vector<Shape>(shape.tuple_shapes());
|
||||
})
|
||||
.def(
|
||||
"with_major_to_minor_layout_if_absent",
|
||||
[](const Shape& shape) {
|
||||
Shape out = shape;
|
||||
ShapeUtil::ForEachMutableSubshape(
|
||||
&out, [](Shape* subshape, const ShapeIndex&) {
|
||||
if (!subshape->has_layout()) {
|
||||
LayoutUtil::SetToDefaultLayout(subshape);
|
||||
}
|
||||
});
|
||||
return out;
|
||||
},
|
||||
"Returns a copy of a shape with missing layouts set to "
|
||||
"major-to-minor.")
|
||||
.def("__eq__", [](const Shape& shape,
|
||||
const Shape& other) { return shape == other; })
|
||||
.def("__ne__", [](const Shape& shape,
|
||||
const Shape& other) { return shape != other; })
|
||||
.def("__hash__",
|
||||
[](const Shape& shape) { return absl::Hash<Shape>()(shape); })
|
||||
.def("__repr__", [](const Shape& shape) {
|
||||
return shape.ToString(/*print_layouts=*/true);
|
||||
});
|
||||
@ -171,10 +232,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
*program_shape.mutable_result() = result;
|
||||
return program_shape;
|
||||
}))
|
||||
.def("Parameters",
|
||||
.def("parameter_shapes",
|
||||
static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
|
||||
&ProgramShape::parameters))
|
||||
.def("Result", &ProgramShape::result)
|
||||
.def("result_shape", &ProgramShape::result)
|
||||
.def("__repr__", &ProgramShape::ToString);
|
||||
|
||||
// Literals
|
||||
@ -211,22 +272,25 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
// CPU custom-call targets.
|
||||
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
|
||||
|
||||
// The LocalClient object allows dynamic attributes to allow external backends
|
||||
// (e.g., TPU) to stash private data in the client.
|
||||
py::class_<PyLocalClient>(m, "LocalClient", py::dynamic_attr())
|
||||
.def_static("Get", &PyLocalClient::Get)
|
||||
py::class_<PyLocalClient, std::shared_ptr<PyLocalClient>>(m, "LocalClient")
|
||||
.def_static("Get", &PyLocalClient::Get, py::arg("platform"),
|
||||
py::arg("xla_platform_id"), py::arg("asynchronous"))
|
||||
.def("DeviceCount", &PyLocalClient::device_count)
|
||||
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
|
||||
.def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
|
||||
|
||||
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
|
||||
.def_static("FromPython", &PyLocalBuffer::FromPython)
|
||||
.def_static("FromPythonValues", &PyLocalBuffer::FromPythonValues)
|
||||
.def_static("MakeTuple", &PyLocalBuffer::MakeTuple)
|
||||
.def("Delete", &PyLocalBuffer::Delete)
|
||||
.def("DestructureTuple", &PyLocalBuffer::DestructureTuple)
|
||||
.def("ToPython", &PyLocalBuffer::ToPython)
|
||||
.def("shape", &PyLocalBuffer::on_host_shape);
|
||||
.def_static("from_python", &PyLocalBuffer::FromPython)
|
||||
.def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
|
||||
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
|
||||
.def("delete", &PyLocalBuffer::Delete)
|
||||
.def("destructure", &PyLocalBuffer::DestructureTuple)
|
||||
.def("to_py", &PyLocalBuffer::ToPython)
|
||||
.def("shape", &PyLocalBuffer::on_host_shape)
|
||||
.def("device", &PyLocalBuffer::device_ordinal)
|
||||
.def("is_deleted", [](const PyLocalBuffer& buffer) {
|
||||
return buffer.device_buffer() == nullptr;
|
||||
});
|
||||
|
||||
py::class_<PyLocalExecutable>(m, "LocalExecutable")
|
||||
.def_static("Compile", &PyLocalExecutable::Compile,
|
||||
@ -301,7 +365,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
// XlaBuilder.
|
||||
py::module ops = m.def_submodule("ops", "XLA operations");
|
||||
|
||||
ops.def("AllReduce",
|
||||
static_cast<XlaOp (*)(
|
||||
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||
const absl::optional<ChannelHandle>&)>(&CrossReplicaSum));
|
||||
ops.def("AllToAll", &AllToAll);
|
||||
ops.def("CollectivePermute", &CollectivePermute);
|
||||
ops.def("CrossReplicaSum",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
||||
&CrossReplicaSum));
|
||||
|
@ -28,7 +28,6 @@ import os
|
||||
import numpy as np
|
||||
|
||||
import six
|
||||
from six.moves import xrange
|
||||
|
||||
# Note this module does *not* depend on any Python protocol buffers. The XLA
|
||||
# Python bindings are currently packaged both as part of jaxlib and as part
|
||||
@ -71,18 +70,10 @@ class Backend(object):
|
||||
for pyval, device in pyvals_and_devices
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_buffer(self, c_buffer):
|
||||
"""Deletes buffer `c_buffer`."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_tuple(self, c_buffers, device_ordinal):
|
||||
"""Makes a tuple from a sequence of backend buffer objects."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def destructure_tuple(self, c_buffer):
|
||||
"""Destructures a tuple buffer into a sequence of buffers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, computation, compile_options):
|
||||
"""Compiles a computation. Returns an executable."""
|
||||
@ -103,47 +94,38 @@ class Backend(object):
|
||||
class LocalBackend(Backend):
|
||||
"""XLA backend implemented using the in-process xla::LocalClient API."""
|
||||
|
||||
def __init__(self, platform=None, xla_platform_id=None, asynchronous=False):
|
||||
def __init__(self, platform, client):
|
||||
"""Creates a new LocalBackend.
|
||||
|
||||
Args:
|
||||
platform: A string; the user-visible platform name, e.g. 'gpu'.
|
||||
xla_platform_id: A string; XLA's name for the platform, e.g., 'CUDA'.
|
||||
asynchronous: A boolean; should we enable asynchronous execution?
|
||||
(Experimental.)
|
||||
client: An _xla.PyLocalClient object.
|
||||
"""
|
||||
super(LocalBackend, self).__init__(platform)
|
||||
self.client = _xla.LocalClient.Get(platform, xla_platform_id, asynchronous)
|
||||
self.client = client
|
||||
|
||||
def device_count(self):
|
||||
return self.client.DeviceCount()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=0):
|
||||
return _xla.PyLocalBuffer.FromPython(pyval, self.client, device)
|
||||
return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
|
||||
|
||||
def buffers_from_pyvals(self, pyvals_and_devices):
|
||||
return _xla.PyLocalBuffer.FromPythonValues(pyvals_and_devices, self.client)
|
||||
|
||||
def delete_buffer(self, c_buffer):
|
||||
c_buffer.Delete()
|
||||
return _xla.PyLocalBuffer.from_python_values(pyvals_and_devices,
|
||||
self.client)
|
||||
|
||||
def make_tuple(self, c_buffers, device_ordinal):
|
||||
return _xla.PyLocalBuffer.MakeTuple(c_buffers, self.client, device_ordinal)
|
||||
|
||||
def destructure_tuple(self, c_buffer):
|
||||
return c_buffer.DestructureTuple()
|
||||
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal)
|
||||
|
||||
def compile(self, c_computation, compile_options):
|
||||
options = _xla.ExecutableBuildOptions()
|
||||
options.num_replicas = compile_options.num_replicas
|
||||
if compile_options.argument_layouts:
|
||||
argument_layouts = [
|
||||
s.as_xla_shape() for s in compile_options.argument_layouts
|
||||
]
|
||||
argument_layouts = compile_options.argument_layouts
|
||||
else:
|
||||
argument_layouts = c_computation.GetProgramShape().Parameters()
|
||||
argument_layouts = c_computation.GetProgramShape().parameter_shapes()
|
||||
if compile_options.result_layout:
|
||||
options.result_layout = compile_options.result_layout.as_xla_shape()
|
||||
options.result_layout = compile_options.result_layout
|
||||
options.debug_options.xla_cpu_fast_math_honor_infs = True
|
||||
options.debug_options.xla_cpu_fast_math_honor_nans = True
|
||||
return _xla.LocalExecutable.Compile(c_computation, argument_layouts,
|
||||
@ -159,10 +141,22 @@ class LocalBackend(Backend):
|
||||
return executable.ExecutePerReplica(per_replica_args)
|
||||
|
||||
|
||||
def _cpu_backend_factory():
|
||||
client = _xla.LocalClient.Get(
|
||||
platform='cpu', xla_platform_id='Host', asynchronous=True)
|
||||
return LocalBackend(platform='cpu', client=client)
|
||||
|
||||
|
||||
def _gpu_backend_factory():
|
||||
client = _xla.LocalClient.Get(
|
||||
platform='gpu', xla_platform_id='CUDA', asynchronous=False)
|
||||
return LocalBackend(platform='gpu', client=client)
|
||||
|
||||
|
||||
# Backend factories, keyed by user-visible name, in increasing priority order.
|
||||
_local_backend_factories = collections.OrderedDict([
|
||||
('cpu', lambda: LocalBackend(platform='cpu', xla_platform_id='Host')),
|
||||
('gpu', lambda: LocalBackend(platform='gpu', xla_platform_id='CUDA')),
|
||||
('cpu', _cpu_backend_factory),
|
||||
('gpu', _gpu_backend_factory),
|
||||
])
|
||||
|
||||
|
||||
@ -291,95 +285,12 @@ def dtype_to_etype(dtype):
|
||||
return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
|
||||
|
||||
|
||||
class Buffer(object):
|
||||
"""Represents a handle to data owned by XLA.
|
||||
|
||||
The referent is ready for use in executing a local, compiled
|
||||
Computation. On XLA platforms involving a device (e.g. GPU), this
|
||||
means the referent is in device memory.
|
||||
"""
|
||||
|
||||
def __init__(self, c_buffer, backend, device):
|
||||
self.c_buffer = c_buffer
|
||||
self._backend = backend
|
||||
self._device = device
|
||||
|
||||
@staticmethod
|
||||
def from_pyval(pyval, device=0, backend=None):
|
||||
"""Copies the `pyval` to a freshly allocated on-device buffer."""
|
||||
backend = backend or get_local_backend()
|
||||
pyval = require_numpy_array_layout(pyval)
|
||||
cbuf = backend.buffer_from_pyval(pyval, device)
|
||||
return Buffer(cbuf, backend, device)
|
||||
|
||||
@staticmethod
|
||||
def from_pyvals(pyvals_and_devices, backend=None):
|
||||
"""Copies multiple Python values to freshly allocated on-device buffers.
|
||||
|
||||
Arguments:
|
||||
pyvals_and_devices: a list of `(pyval, device)` pairs, where `pyval` is
|
||||
a Python value to copy (e.g., a NumPy array), and `device` is an integer
|
||||
device ordinal.
|
||||
backend: a Backend object, or `None` to use the default local backend.
|
||||
Returns:
|
||||
A list of `Buffer` objects corresponding to `pyvals_and_devices`.
|
||||
"""
|
||||
backend = backend or get_local_backend()
|
||||
pyvals_and_devices = [(require_numpy_array_layout(pyval), device)
|
||||
for pyval, device in pyvals_and_devices]
|
||||
cbufs = backend.buffers_from_pyvals(pyvals_and_devices)
|
||||
return [
|
||||
Buffer(cbuf, backend, device)
|
||||
for cbuf, (_, device) in zip(cbufs, pyvals_and_devices)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def make_tuple(buffers, backend=None, device=0):
|
||||
backend = backend or get_local_backend()
|
||||
buf = backend.make_tuple([b.c_buffer for b in buffers],
|
||||
device_ordinal=device)
|
||||
return Buffer(buf, backend, device)
|
||||
|
||||
def to_py(self):
|
||||
return self.c_buffer.ToPython()
|
||||
|
||||
def shape(self):
|
||||
return _wrap_shape(self.c_buffer.shape())
|
||||
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
def delete(self):
|
||||
if self.c_buffer is not None:
|
||||
self._backend.delete_buffer(self.c_buffer)
|
||||
self.c_buffer = None
|
||||
|
||||
def destructure(self):
|
||||
"""Assuming a tuple buffer, unpack it into constituent tuple elements."""
|
||||
assert self.c_buffer is not None
|
||||
result = self._backend.destructure_tuple(self.c_buffer)
|
||||
return tuple(
|
||||
Buffer(sub_buffer, device=self._device, backend=self._backend)
|
||||
for sub_buffer in result)
|
||||
|
||||
def is_deleted(self):
|
||||
return self.c_buffer is None
|
||||
|
||||
|
||||
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
|
||||
# compatibility with Jaxlib versions older than 0.1.13.
|
||||
LocalBuffer = Buffer
|
||||
|
||||
|
||||
class Format(enum.IntEnum):
|
||||
"""Python copy of the Format protocol buffer enum."""
|
||||
INVALID_FORMAT = 0
|
||||
DENSE = 1
|
||||
SPARSE = 2
|
||||
|
||||
Shape = _xla.Shape
|
||||
Shape.__doc__ = """
|
||||
A Shape is an object defined in C++ that duck types like the following class:
|
||||
|
||||
class Shape(object):
|
||||
"""Represents an XLA shape.
|
||||
'''Represents an XLA shape.
|
||||
|
||||
A shape is either an array shape, having rank-many integer
|
||||
dimensions and an element type (represented by a Numpy dtype), or it
|
||||
@ -388,188 +299,120 @@ class Shape(object):
|
||||
type shape =
|
||||
TupleShape of shape list
|
||||
| ArrayShape of { dimensions: int list; element_type: dtype }
|
||||
'''
|
||||
|
||||
Callers are expected to instantiate this class only via the static
|
||||
constructors: tuple_shape, array_shape, and from_pyval.
|
||||
@staticmethod
|
||||
def tuple_shape(tuple_shapes) -> Shape:
|
||||
"Construct a tuple shape."
|
||||
|
||||
@staticmethod
|
||||
def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
|
||||
|
||||
@staticmethod
|
||||
def from_pyval(pyval) -> Shape:
|
||||
"Returns a Shape that describes a tuple-tree of Numpy arrays."
|
||||
|
||||
def __eq__(self, other: Shape) -> bool:
|
||||
def __ne__(self, other: Shape) -> bool:
|
||||
def __hash__(self):
|
||||
def __repr__(self):
|
||||
def is_tuple(self) -> bool:
|
||||
def is_array(self) -> bool:
|
||||
def tuple_shapes(self) -> [Shape]:
|
||||
def numpy_dtype(self) -> np.dtype:
|
||||
"Like element_type(), but returns dtype('O') for a tuple shape."
|
||||
def xla_element_type(self) -> PrimitiveType:
|
||||
def element_type(self) -> np.dtype:
|
||||
def dimensions(self) -> (int, int, ...):
|
||||
def rank(self) -> int:
|
||||
def minor_to_major(self) -> [int]:
|
||||
def with_major_to_minor_layout_if_absent(self) -> Shape:
|
||||
"Returns a copy with missing layouts set to major-to-minor."
|
||||
|
||||
def to_serialized_proto(self) -> bytes:
|
||||
"Returns 'shape' as a serialized proto."
|
||||
"""
|
||||
|
||||
ProgramShape = _xla.ProgramShape
|
||||
ProgramShape.__doc__ = """
|
||||
A ProgramShape is a C++ object that duck types like the following class.
|
||||
|
||||
class ProgramShape(object):
|
||||
def __init__(self, parameter_shapes, result_shape):
|
||||
def parameter_shapes(self) -> [Shape]:
|
||||
def result_shape(self) -> Shape:
|
||||
def __repr__(self):
|
||||
"""
|
||||
|
||||
|
||||
  @staticmethod
  def tuple_shape(tuple_shapes):
    """Construct a tuple shape."""
    if (not isinstance(tuple_shapes, (tuple, list)) or
        not all(isinstance(t, Shape) for t in tuple_shapes)):
      raise TypeError('tuple_shapes must be a tuple of Shapes')
    return Shape(tuple_shapes, tuple)

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None):
    """Construct an array shape."""
    if (not isinstance(dimensions, tuple) or
        not all(isinstance(i, int) for i in dimensions)):
      dimensions = tuple(int(i) for i in dimensions)
    return Shape(
        dimensions, np.dtype(element_type), minor_to_major=minor_to_major)

  @staticmethod
  def from_pyval(pyval):
    """Returns a Shape that describes a tuple-tree of Numpy arrays."""

    def convert(pyval):
      if isinstance(pyval, tuple):
        return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
      else:
        pyval = require_numpy_array_layout(pyval)
        return Shape.array_shape(pyval.dtype, np.shape(pyval))

    return convert(pyval)

  def __init__(self, dimensions, dtype, minor_to_major=None):
    assert isinstance(dimensions, tuple)
    self._dimensions = dimensions
    self._dtype = dtype
    self._is_tuple = dtype == tuple
    self._minor_to_major = minor_to_major
    self._check_minor_to_major()

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (self._dtype == other._dtype and
            self._dimensions == other._dimensions and
            self._minor_to_major == other._minor_to_major)

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash((self._dtype, self._dimensions, self._minor_to_major))

  def __repr__(self):
    return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
            '_is_tuple={!r}, _minor_to_major={!r})').format(
                self._dtype, self._dimensions, self._is_tuple,
                self._minor_to_major)

  def is_tuple(self):
    return self._is_tuple

  def is_array(self):
    return not self._is_tuple

  def tuple_shapes(self):
    if not self.is_tuple():
      raise ValueError('not a tuple shape')
    return self._dimensions

  def numpy_dtype(self):
    """Like element_type(), but returns dtype('O') in case of a tuple shape."""
    if self.is_tuple():
      return np.dtype(np.object)
    else:
      return self.element_type()

  def xla_element_type(self):
    return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())]

  def element_type(self):
    if not self.is_array():
      raise ValueError('not an array shape')
    return self._dtype

  def dimensions(self):
    if not self.is_array():
      raise ValueError('not an array shape')
    return self._dimensions

  def rank(self):
    return len(self.dimensions())

  def minor_to_major(self):
    return self._minor_to_major

  def map_leaves(self, f):
    """Map f over each leaf-level array subshape.

    Args:
      f: The function to apply. Whenever f returns None, the identity is applied
        instead.

    Returns:
      A new Shape with the mapped leaves.
    """
    if self.is_tuple():
      children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
      return Shape.tuple_shape(children)
    else:
      mapped = f(self)
      return self if mapped is None else mapped

  def _check_minor_to_major(self):
    mtm = self._minor_to_major
    if self.is_tuple():
      assert mtm is None, self
    if mtm is not None:
      assert self.rank() == len(mtm), self
      assert sorted(mtm) == list(range(len(mtm))), self

  def update_minor_to_major(self, minor_to_major):
    if not self.is_array():
      raise ValueError('not an array shape')
    if not isinstance(minor_to_major, tuple):
      raise TypeError('minor_to_major must be a tuple')
    updated = Shape.array_shape(self.element_type(), self.dimensions(),
                                minor_to_major)
    updated._check_minor_to_major()  # pylint: disable=protected-access
    return updated

  def with_major_to_minor_layout_if_absent(self):
    """Returns a copy of a shape with missing layouts set to major-to-minor."""

    def f(a):
      if a.minor_to_major():
        return None
      return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1)))

    return self.map_leaves(f)

  def serialize(self, proto):
    """Serializes 'shape' into proto."""
    if self.is_tuple():
      proto.element_type = int(PrimitiveType.TUPLE)
      for shape in self.tuple_shapes():
        shape.serialize(proto.tuple_shapes.add())
    else:
      proto.element_type = int(self.xla_element_type())
      proto.dimensions.extend(self.dimensions())
      proto.is_dynamic_dimension.extend([False for _ in self.dimensions()])
      if self.minor_to_major():
        proto.layout.format = Format.DENSE
        proto.layout.minor_to_major.extend(self.minor_to_major())

  def as_xla_shape(self):
    if self.is_tuple():
      return _xla.Shape.Tuple([x.as_xla_shape() for x in self.tuple_shapes()])

    return _xla.Shape.Array(self.xla_element_type(), self.dimensions(),
                            self.minor_to_major())


ProgramShape = collections.namedtuple('ProgramShape',
                                      ('parameter_shapes', 'result_shape'))


def _wrap_shape(xla_shape):
  element_type = xla_shape.element_type()
  if element_type == PrimitiveType.TUPLE:
    shapes = tuple(_wrap_shape(sub) for sub in xla_shape.tuple_shapes())
    return Shape.tuple_shape(shapes)
  else:
    dtype = XLA_ELEMENT_TYPE_TO_DTYPE[element_type]
    return Shape.array_shape(dtype, xla_shape.dimensions())


def _wrap_program_shape(program_shape):
  return ProgramShape([_wrap_shape(arg) for arg in program_shape.Parameters()],
                      _wrap_shape(program_shape.Result()))


class Buffer(object):
  """Represents a handle to data owned by XLA.

  The referent is ready for use in executing a local, compiled
  Computation. On XLA platforms involving a device (e.g. GPU), this
  means the referent is in device memory.
  """

  @staticmethod
  def from_pyval(pyval, device=0, backend=None):
    """Copies the `pyval` to a freshly allocated on-device buffer."""
    backend = backend or get_local_backend()
    pyval = require_numpy_array_layout(pyval)
    return backend.buffer_from_pyval(pyval, device)

  @staticmethod
  def from_pyvals(pyvals_and_devices, backend=None):
    """Copies multiple Python values to freshly allocated on-device buffers.

    Arguments:
      pyvals_and_devices: a list of `(pyval, device)` pairs, where `pyval` is a
        Python value to copy (e.g., a NumPy array), and `device` is an integer
        device ordinal.
      backend: a Backend object, or `None` to use the default local backend.

    Returns:
      A list of `Buffer` objects corresponding to `pyvals_and_devices`.
    """
    backend = backend or get_local_backend()
    pyvals_and_devices = [(require_numpy_array_layout(pyval), device)
                          for pyval, device in pyvals_and_devices]
    return backend.buffers_from_pyvals(pyvals_and_devices)

  @staticmethod
  def make_tuple(buffers, backend=None, device=0):
    backend = backend or get_local_backend()
    return backend.make_tuple(buffers, device_ordinal=device)

# Buffer is not an instantiable type and exists only for its static methods.
# The underlying buffer objects are C++ objects with the following
# API:
# def to_py(self):
# def shape(self) -> Shape:
# def device(self) -> int:
# def delete(self):
# def destructure(self) -> [Buffer]
# def is_deleted(self) -> bool:
#
# TODO(phawkins): remove Buffer and its static methods completely, have
# clients call methods on Backend to create buffers.


# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      pyval = require_numpy_array_layout(pyval)
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)

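# Usage sketch (illustrative, not part of this change; the buffer methods used
# below are the ones listed in the comment above):
#
#   buf = Buffer.from_pyval(np.arange(4, dtype=np.float32))
#   buf.shape()         # Shape describing f32[4]
#   host = buf.to_py()  # copies the data back into a NumPy array
#   buf.delete()
#   buf.is_deleted()    # True
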
def require_numpy_array_layout(value):
@ -612,8 +455,7 @@ def transfer_from_outfeed(shape, device_ordinal=0):
|
||||
# TODO(phawkins): support non-default backends.
|
||||
backend = get_local_backend()
|
||||
return backend.client.TransferFromOutfeed(
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
|
||||
device_ordinal)
|
||||
shape.with_major_to_minor_layout_if_absent(), device_ordinal)
|
||||
|
||||
|
||||
class CompileOptions(object):
|
||||
@ -699,10 +541,10 @@ class Computation(object):
|
||||
return Executable(c, backend=backend)
|
||||
|
||||
def GetProgramShape(self):
|
||||
return _wrap_program_shape(self._c_computation.GetProgramShape())
|
||||
return self._c_computation.GetProgramShape()
|
||||
|
||||
def GetReturnValueShape(self):
|
||||
return _wrap_shape(self._c_computation.GetProgramShape().Result())
|
||||
return self._c_computation.GetProgramShape().result_shape()
|
||||
|
||||
|
||||
class Executable(object):
|
||||
@ -717,14 +559,12 @@ class Executable(object):
|
||||
"""Returns a list containing the device ordinals for each replica."""
|
||||
return self._device_ordinals
|
||||
|
||||
def Execute(self, arguments=(), check_for_deleted_args=True):
|
||||
def Execute(self, arguments=None, check_for_deleted_args=True):
|
||||
"""Execute on one replica with Buffer arguments and return value."""
|
||||
arguments = arguments or []
|
||||
if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
|
||||
raise ValueError('Executing with deleted local buffer argument')
|
||||
raw_args = [arg.c_buffer for arg in arguments]
|
||||
output_buffer = self._backend.execute(self._c_executable, raw_args)
|
||||
return Buffer(
|
||||
output_buffer, backend=self._backend, device=self._device_ordinals[0])
|
||||
return self._backend.execute(self._c_executable, arguments)
|
||||
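# Usage sketch (illustrative, not part of this change; `computation` is a built
# Computation and `x` a NumPy value):
#
#   executable = computation.Compile()
#   out_buf = executable.Execute([Buffer.from_pyval(x)])   # returns a Buffer
#   out = executable.ExecuteWithPythonValues((x,))         # returns NumPy data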
|
||||
def ExecutePerReplica(self, arguments=None):
|
||||
"""Execute on many replicas with Buffer arguments and return value.
|
||||
@ -753,23 +593,8 @@ class Executable(object):
|
||||
'Executing on device {} with argument from device {}'.format(
|
||||
self._device_ordinals[replica], arg.device()))
|
||||
|
||||
# Pull out argument buffer handles
|
||||
# pylint: disable=g-complex-comprehension
|
||||
stripped_args = [
|
||||
[arg.c_buffer for arg in replica_args] for replica_args in arguments
|
||||
]
|
||||
|
||||
# Execute
|
||||
output_buffers = self._backend.execute_replicated(self._c_executable,
|
||||
stripped_args)
|
||||
|
||||
# Wrap output handles in Buffer instances
|
||||
return tuple(
|
||||
Buffer(
|
||||
output_buffer,
|
||||
backend=self._backend,
|
||||
device=self._device_ordinals[replica])
|
||||
for replica, output_buffer in enumerate(output_buffers))
|
||||
return self._backend.execute_replicated(self._c_executable, arguments)
|
||||
|
||||
def ExecuteWithPythonValues(self, arguments=()):
|
||||
"""Execute on one replica with Python values as arguments and output."""
|
||||
@ -877,7 +702,7 @@ class ComputationBuilder(object):
|
||||
return Computation(self._builder.Build(), backend=backend)
|
||||
|
||||
def GetShape(self, operand):
|
||||
return _wrap_shape(self._builder.GetShape(operand))
|
||||
return self._builder.GetShape(operand)
|
||||
|
||||
def SetOpMetadata(self, op_metadata):
|
||||
"""Set metadata for operations that are about to be enqueued."""
|
||||
@ -896,9 +721,8 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp.
|
||||
"""
|
||||
return ops.Infeed(
|
||||
self._builder,
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape())
|
||||
return ops.Infeed(self._builder,
|
||||
shape.with_major_to_minor_layout_if_absent())
|
||||
|
||||
def Outfeed(self, operand):
|
||||
"""Enqueues an outfeed op onto the computation.
|
||||
@ -995,10 +819,9 @@ class ComputationBuilder(object):
|
||||
if parameter_num is None:
|
||||
parameter_num = next(self._parameter_numbering)
|
||||
|
||||
return ops.Parameter(
|
||||
self._builder, parameter_num,
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
|
||||
name.encode('utf8'))
|
||||
return ops.Parameter(self._builder, parameter_num,
|
||||
shape.with_major_to_minor_layout_if_absent(),
|
||||
name.encode('utf8'))
|
||||
|
||||
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
|
||||
"""Enqueues a Parameter op onto the computation.
|
||||
@ -1013,7 +836,7 @@ class ComputationBuilder(object):
|
||||
An XlaOp.
|
||||
"""
|
||||
return self.ParameterWithShape(
|
||||
Shape.from_pyval(value), name=name, parameter_num=parameter_num)
|
||||
shape_from_pyval(value), name=name, parameter_num=parameter_num)
|
||||
|
||||
def Iota(self, dtype, size):
|
||||
"""Enqueues an iota constant onto the computation.
|
||||
@ -1040,7 +863,7 @@ class ComputationBuilder(object):
|
||||
An XlaOp representing the added broadcasted iota constant.
|
||||
"""
|
||||
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
|
||||
xla_shape = _xla.Shape.Array(element_type, shape, None)
|
||||
xla_shape = _xla.Shape.array_shape(element_type, shape, None)
|
||||
return ops.Iota(self._builder, xla_shape, dimension)
|
||||
|
||||
def Concatenate(self, operands, dimension):
|
||||
@ -1097,6 +920,24 @@ class ComputationBuilder(object):
|
||||
dimensions = tuple(range(ndim))
|
||||
return ops.Reshape(operand, dimensions, new_sizes)
|
||||
|
||||
def AllReduce(self, operand, computation, replica_groups=None):
|
||||
"""AllReduce op.
|
||||
|
||||
Args:
|
||||
operand: XlaOp representing the input array
|
||||
computation: a Computation object - binary reduction function.
|
||||
replica_groups: optional, list of lists of ints encoding a partition of
|
||||
the set {0, 1, ..., num_replicas - 1} into equally-sized replica groups
within which the all-reduce is performed. If not supplied or None (the
|
||||
default), all replicas belong to the same group.
|
||||
|
||||
Returns:
|
||||
An XlaOp that represents the all-reduced result.
|
||||
"""
|
||||
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||
return ops.AllReduce(operand, computation.computation,
|
||||
replica_groups_protos, None)
|
||||
|
||||
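# Usage sketch (illustrative, not part of this change): `operand` is an XlaOp
# and `add_f32` stands for a scalar binary-add Computation built separately.
#
#   c.AllReduce(operand, add_f32, replica_groups=[[0, 1], [2, 3]])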
def AllToAll(self,
|
||||
operand,
|
||||
split_dimension,
|
||||
@ -1117,13 +958,7 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp that represents the all-to-all concatenation.
|
||||
"""
|
||||
if replica_groups is None:
|
||||
replica_groups_protos = [] # special value for XLA API
|
||||
else:
|
||||
replica_groups = list(replica_groups)
|
||||
replica_groups_protos = [
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||
if not replica_groups:
|
||||
split_count = 1
|
||||
else:
|
||||
@ -1146,13 +981,8 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp that represents on each replica the sum of its group's values.
|
||||
"""
|
||||
if replica_groups is None:
|
||||
replica_groups = [] # special value for XLA API
|
||||
else:
|
||||
replica_groups = [
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
return ops.CrossReplicaSum(operand, replica_groups)
|
||||
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||
return ops.CrossReplicaSum(operand, replica_groups_protos)
|
||||
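# Usage sketch (illustrative, not part of this change; assumes the usual
# `CrossReplicaSum(operand, replica_groups=None)` signature, which is not
# shown in this hunk):
#
#   c.CrossReplicaSum(operand, replica_groups=[[0, 1]])  # sum over replicas 0 and 1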
|
||||
def Trans(self, operand):
|
||||
"""Specialized matrix transpose op."""
|
||||
@ -1298,10 +1128,9 @@ class ComputationBuilder(object):
|
||||
An XlaOp representing the added custom call op.
|
||||
"""
|
||||
opaque = opaque or b''
|
||||
return ops.CustomCall(
|
||||
self._builder, call_target_name, list(operands),
|
||||
shape_with_layout.as_xla_shape(),
|
||||
[s.as_xla_shape() for s in operand_shapes_with_layout], opaque)
|
||||
return ops.CustomCall(self._builder, call_target_name,
|
||||
list(operands), shape_with_layout,
|
||||
list(operand_shapes_with_layout), opaque)
|
||||
|
||||
def Map(self, operands, computation_to_apply, dimensions):
|
||||
"""Enqueues a map operation onto the computation.
|
||||
@ -1389,7 +1218,7 @@ class ComputationBuilder(object):
|
||||
dims: A 1D array-like of nonnegative integers specifying the dimensions.
|
||||
Returns: an XlaOp to the generated array of F32 values.
|
||||
"""
|
||||
shape = _xla.Shape.Array(self.GetShape(mu).xla_element_type(), dims)
|
||||
shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims)
|
||||
return ops.RngNormal(mu, sigma, shape)
|
||||
|
||||
def RngUniform(self, a, b, dims):
|
||||
@ -1406,7 +1235,7 @@ class ComputationBuilder(object):
|
||||
Returns: an XlaOp to the generated array of values with the same numeric type
|
||||
(F32, S32, or U32) as the arguments a and b.
|
||||
"""
|
||||
shape = _xla.Shape.Array(self.GetShape(a).xla_element_type(), dims)
|
||||
shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims)
|
||||
return ops.RngUniform(a, b, shape)
|
||||
|
||||
def While(self, cond, body, init):
|
||||
@ -1659,7 +1488,6 @@ class ComputationBuilder(object):
|
||||
|
||||
FftType = _xla.FftType
|
||||
|
||||
|
||||
_UNARY_OPS = [
|
||||
'Not',
|
||||
'Clz',
|
||||
@ -1732,6 +1560,7 @@ _OTHER_OPS = [
|
||||
'Cholesky',
|
||||
'Clamp',
|
||||
'Collapse',
|
||||
'CollectivePermute',
|
||||
'ConvertElementType',
|
||||
'Dot',
|
||||
'Gather',
|
||||
@ -1885,3 +1714,14 @@ def _make_replica_group_proto(replica_group):
|
||||
replica_group_proto = ReplicaGroup()
|
||||
replica_group_proto.replica_ids.extend(replica_group)
|
||||
return replica_group_proto
|
||||
|
||||
|
||||
def _get_replica_groups_protos(replica_groups):
|
||||
if replica_groups is None:
|
||||
replica_groups_protos = [] # special value for XLA API
|
||||
else:
|
||||
replica_groups = list(replica_groups)
|
||||
replica_groups_protos = [
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
return replica_groups_protos
|
||||
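# Illustration (not part of this change): `None` maps to the empty list, which
# the XLA API treats as "all replicas in one group"; otherwise every group is
# wrapped in a ReplicaGroup proto.
#
#   _get_replica_groups_protos(None)              # -> []
#   _get_replica_groups_protos([[0, 1], [2, 3]])  # -> two ReplicaGroup protos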
|
@ -315,10 +315,11 @@ class ComputationsWithConstantsTest(ComputationTest):
|
||||
c.CustomCall(
|
||||
b"test_subtract_f32",
|
||||
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
|
||||
shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()),
|
||||
shape_with_layout=xla_client.Shape.array_shape(
|
||||
np.dtype(np.float32), (), ()),
|
||||
operand_shapes_with_layout=(
|
||||
xla_client.Shape.array_shape(np.float32, (), ()),
|
||||
xla_client.Shape.array_shape(np.float32, (), ()),
|
||||
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
|
||||
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
|
||||
))
|
||||
self._ExecuteAndCompareClose(c, expected=0.75)
|
||||
|
||||
@ -1745,7 +1746,7 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
def testInfeedS32Values(self):
|
||||
to_infeed = NumpyArrayS32([1, 2, 3, 4])
|
||||
c = self._NewComputation()
|
||||
c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
|
||||
c.Infeed(xla_client.shape_from_pyval(to_infeed[0]))
|
||||
compiled_c = c.Build().Compile()
|
||||
for item in to_infeed:
|
||||
xla_client.transfer_to_infeed(item)
|
||||
@ -1757,7 +1758,7 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
def testInfeedThenOutfeedS32(self):
|
||||
to_round_trip = NumpyArrayS32([1, 2, 3, 4])
|
||||
c = self._NewComputation()
|
||||
x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
|
||||
x = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0]))
|
||||
c.Outfeed(x)
|
||||
|
||||
compiled_c = c.Build().Compile()
|
||||
@ -1767,7 +1768,7 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
execution.start()
|
||||
xla_client.transfer_to_infeed(want)
|
||||
got = xla_client.transfer_from_outfeed(
|
||||
xla_client.Shape.from_pyval(to_round_trip[0]))
|
||||
xla_client.shape_from_pyval(to_round_trip[0]))
|
||||
execution.join()
|
||||
self.assertEqual(want, got)
|
||||
|
||||
@ -1803,7 +1804,9 @@ class ErrorTest(ComputationTest):
|
||||
c.ClearOpMetadata()
|
||||
|
||||
options = xla_client.CompileOptions()
|
||||
options.argument_layouts = [xla_client.Shape.array_shape(np.float32, [])]
|
||||
options.argument_layouts = [
|
||||
xla_client.Shape.array_shape(np.dtype(np.float32), [])
|
||||
]
|
||||
|
||||
def TestFun():
|
||||
return c.Build().Compile(compile_options=options)
|
||||
|
@ -84,9 +84,9 @@ void AddXrtSubmodule(py::module* module) {
|
||||
.def_property_readonly("tf_device_ids", &XrtContext::tf_device_ids);
|
||||
|
||||
py::class_<XrtBuffer, std::shared_ptr<XrtBuffer>>(m, "XrtBuffer")
|
||||
.def_static("FromLiteral", &XrtBuffer::FromLiteral)
|
||||
.def_static("MakeTuple", &XrtBuffer::MakeTuple)
|
||||
.def("ToPython",
|
||||
.def_static("from_literal", &XrtBuffer::FromLiteral)
|
||||
.def_static("make_tuple", &XrtBuffer::MakeTuple)
|
||||
.def("to_py",
|
||||
[](std::shared_ptr<XrtBuffer> buffer) -> xla::StatusOr<py::object> {
|
||||
auto literal = absl::make_unique<xla::Literal>();
|
||||
{
|
||||
@ -95,8 +95,10 @@ void AddXrtSubmodule(py::module* module) {
|
||||
}
|
||||
return xla::LiteralToPython(std::move(literal));
|
||||
})
|
||||
.def("Delete", &XrtBuffer::Delete)
|
||||
.def("DestructureTuple", &XrtBuffer::DestructureTuple);
|
||||
.def("delete", &XrtBuffer::Delete)
|
||||
.def("destructure", &XrtBuffer::DestructureTuple)
|
||||
.def("is_deleted",
|
||||
[](const XrtBuffer& buffer) { return !buffer.handle().valid(); });
|
||||
|
||||
py::class_<XrtExecutable, std::shared_ptr<XrtExecutable>>(m, "XrtExecutable")
|
||||
.def_static("Compile",
|
||||
|
@ -31,13 +31,6 @@ from tensorflow.compiler.xla.python import xla_extension as _xla
|
||||
# pylint: enable=g-direct-tensorflow-import
|
||||
|
||||
|
||||
def _make_xla_shape(shape):
|
||||
if shape.is_tuple():
|
||||
return _xla.Shape.Tuple([_make_xla_shape(s) for s in shape.tuple_shapes()])
|
||||
return _xla.Shape.Array(shape.xla_element_type(), shape.dimensions(),
|
||||
shape.minor_to_major())
|
||||
|
||||
|
||||
def get_tf_context(target, worker):
|
||||
"""Returns a TensorFlow RPC client object.
|
||||
|
||||
@ -60,7 +53,8 @@ class XrtBackend(xla_client.Backend):
|
||||
tf_device_type: the type of TensorFlow device to use for XRT (e.g. `"TPU"`).
|
||||
"""
|
||||
|
||||
def __init__(self, tf_context, tf_device_type):
|
||||
def __init__(self, tf_context, tf_device_type, platform="tpu"):
|
||||
super(XrtBackend, self).__init__(platform)
|
||||
self.tf_device_type = tf_device_type
|
||||
|
||||
self.context = _xla.xrt.XrtContext.Create(tf_context, tf_device_type)
|
||||
@ -69,30 +63,23 @@ class XrtBackend(xla_client.Backend):
|
||||
return self.context.DeviceCount()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=0):
|
||||
return _xla.xrt.XrtBuffer.FromLiteral(self.context, device, pyval)
|
||||
|
||||
def delete_buffer(self, c_buffer):
|
||||
c_buffer.Delete()
|
||||
|
||||
def destructure_tuple(self, c_buffer):
|
||||
return c_buffer.DestructureTuple()
|
||||
return _xla.xrt.XrtBuffer.from_literal(self.context, device, pyval)
|
||||
|
||||
def make_tuple(self, buffers, device_ordinal):
|
||||
return _xla.xrt.XrtBuffer.MakeTuple(self.context, buffers)
|
||||
return _xla.xrt.XrtBuffer.make_tuple(self.context, buffers)
|
||||
|
||||
def compile(self, computation, compile_options):
|
||||
# pylint: disable=protected-access
|
||||
program_shape = xla_client._wrap_program_shape(
|
||||
computation.GetProgramShape())
|
||||
program_shape = computation.GetProgramShape()
|
||||
# pylint: enable=protected-access
|
||||
proto = computation.GetSerializedProto()
|
||||
# TODO(phawkins): use the layouts in compile_options.
|
||||
arg_shapes = [
|
||||
_make_xla_shape(shape.with_major_to_minor_layout_if_absent())
|
||||
for shape in program_shape.parameter_shapes
|
||||
shape.with_major_to_minor_layout_if_absent()
|
||||
for shape in program_shape.parameter_shapes()
|
||||
]
|
||||
result_shape = _make_xla_shape(
|
||||
program_shape.result_shape.with_major_to_minor_layout_if_absent())
|
||||
result_shape = (
|
||||
program_shape.result_shape().with_major_to_minor_layout_if_absent())
|
||||
device_assignment = _xla.xrt.AssignDevices(compile_options.num_replicas, 1)
|
||||
return _xla.xrt.XrtExecutable.Compile(self.context, proto, arg_shapes,
|
||||
result_shape, device_assignment)
|
||||
|
@ -48,7 +48,7 @@ class XrtBackendTest(test.TestCase):
|
||||
b = np.arange(10)
|
||||
|
||||
c = BuildAddAndScaleComputation(
|
||||
xla_client.Shape.from_pyval(a), xla_client.Shape.from_pyval(b))
|
||||
xla_client.shape_from_pyval(a), xla_client.shape_from_pyval(b))
|
||||
|
||||
executable = c.Compile(backend=backend)
|
||||
output = executable.ExecuteWithPythonValues((a, b))
|
||||
|
@ -18,6 +18,9 @@ package_group(
|
||||
includes = [
|
||||
"//tensorflow/compiler/xla:friends",
|
||||
],
|
||||
packages = [
|
||||
"//learning/brain/experimental/tf_runtime/...",
|
||||
],
|
||||
)
|
||||
|
||||
xla_proto_library(
|
||||
@ -434,10 +437,10 @@ tf_cc_test(
|
||||
srcs = ["pattern_matcher_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -505,8 +508,8 @@ cc_library(
|
||||
hdrs = ["hlo_matchers.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -549,13 +552,13 @@ tf_cc_test(
|
||||
srcs = ["hlo_sharding_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
@ -583,6 +586,7 @@ tf_cc_test(
|
||||
srcs = ["call_graph_test.cc"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -590,7 +594,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
@ -653,6 +656,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":call_graph",
|
||||
":flatten_call_graph",
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -660,7 +664,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
@ -691,7 +694,6 @@ cc_library(
|
||||
deps = [
|
||||
":compiler",
|
||||
":computation_placer",
|
||||
":device_memory_allocator",
|
||||
":platform_util",
|
||||
":stream_pool",
|
||||
":transfer_manager",
|
||||
@ -701,6 +703,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -721,7 +724,6 @@ cc_library(
|
||||
":compiler",
|
||||
":computation_layout",
|
||||
":computation_placer",
|
||||
":device_memory_allocator",
|
||||
":dump",
|
||||
":dynamic_dimension_inference",
|
||||
":executable",
|
||||
@ -751,6 +753,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@ -767,7 +770,6 @@ cc_library(
|
||||
":backend",
|
||||
":compiler",
|
||||
":computation_layout",
|
||||
":device_memory_allocator",
|
||||
":executable",
|
||||
":hlo",
|
||||
":hlo_execution_profile",
|
||||
@ -787,6 +789,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@ -855,7 +858,6 @@ cc_library(
|
||||
srcs = ["shaped_buffer.cc"],
|
||||
hdrs = ["shaped_buffer.h"],
|
||||
deps = [
|
||||
":device_memory_allocator",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -865,6 +867,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -878,7 +881,6 @@ tf_cc_test(
|
||||
srcs = ["shaped_buffer_test.cc"],
|
||||
deps = [
|
||||
":cpu_plugin",
|
||||
":device_memory_allocator",
|
||||
":platform_util",
|
||||
":shaped_buffer",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -888,6 +890,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
@ -901,7 +904,6 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":computation_layout",
|
||||
":device_memory_allocator",
|
||||
":dump",
|
||||
":hlo",
|
||||
":hlo_execution_profile",
|
||||
@ -922,6 +924,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -988,7 +991,6 @@ cc_library(
|
||||
hdrs = ["allocation_tracker.h"],
|
||||
deps = [
|
||||
":backend",
|
||||
":device_memory_allocator",
|
||||
":transfer_manager",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -997,6 +999,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -1156,6 +1159,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_memory_scheduler",
|
||||
":hlo_ordering",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1163,7 +1167,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
@ -1205,10 +1208,10 @@ tf_cc_test(
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_memory_scheduler",
|
||||
":hlo_ordering",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
@ -1455,8 +1458,8 @@ tf_cc_test(
|
||||
srcs = ["instruction_fusion_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":instruction_fusion",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
@ -1467,11 +1470,11 @@ cc_library(
|
||||
srcs = ["multi_output_fusion.cc"],
|
||||
hdrs = ["multi_output_fusion.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_reachability",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -1663,6 +1666,7 @@ cc_library(
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1788,8 +1792,8 @@ tf_cc_test(
|
||||
srcs = ["gather_expander_test.cc"],
|
||||
deps = [
|
||||
":gather_expander",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:test_macros_header",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
],
|
||||
@ -1887,9 +1891,9 @@ tf_cc_test(
|
||||
name = "while_loop_analysis_test",
|
||||
srcs = ["while_loop_analysis_test.cc"],
|
||||
deps = [
|
||||
":hlo_parser",
|
||||
":while_loop_analysis",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
@ -2294,7 +2298,7 @@ tf_cc_test(
|
||||
":cpu_plugin",
|
||||
":hlo_cost_analysis",
|
||||
":hlo_execution_profile",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2307,14 +2311,14 @@ tf_cc_test(
|
||||
srcs = ["hlo_computation_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -2519,13 +2523,13 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_liveness_analysis",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2909,12 +2913,12 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_module_dce",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
@ -3040,12 +3044,12 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_cse",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
@ -3229,27 +3233,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "device_memory_allocator",
|
||||
srcs = [
|
||||
"device_memory_allocator.cc",
|
||||
"owning_device_memory.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"device_memory_allocator.h",
|
||||
"owning_device_memory.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "maybe_owning_device_memory",
|
||||
srcs = [
|
||||
@ -3259,7 +3242,7 @@ cc_library(
|
||||
"maybe_owning_device_memory.h",
|
||||
],
|
||||
deps = [
|
||||
":device_memory_allocator",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
@ -3302,10 +3285,10 @@ xla_test(
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
@ -3428,6 +3411,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":shape_inference",
|
||||
":transpose_folding",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -3436,7 +3420,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
@ -3679,10 +3662,10 @@ tf_cc_test(
|
||||
name = "tuple_util_test",
|
||||
srcs = ["tuple_util_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":tuple_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
@ -3708,11 +3691,11 @@ tf_cc_test(
|
||||
name = "while_util_test",
|
||||
srcs = ["while_util_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
],
|
||||
@ -3743,9 +3726,9 @@ tf_cc_test(
|
||||
srcs = ["while_loop_invariant_code_motion_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":while_loop_invariant_code_motion",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -3771,9 +3754,9 @@ tf_cc_test(
|
||||
srcs = ["while_loop_constant_sinking_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":while_loop_constant_sinking",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -3973,6 +3956,8 @@ cc_library(
|
||||
hdrs = ["ar_crs_combiner.h"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -3980,8 +3965,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
@ -4005,11 +3988,11 @@ cc_library(
|
||||
srcs = ["dynamic_index_splitter.cc"],
|
||||
hdrs = ["dynamic_index_splitter.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@ -4124,6 +4107,13 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_call_target_registry",
|
||||
srcs = ["custom_call_target_registry.cc"],
|
||||
hdrs = ["custom_call_target_registry.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "slice_sinker_test",
|
||||
srcs = ["slice_sinker_test.cc"],
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
@ -183,6 +184,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleBroadcast(HloInstruction* broadcast) override;
|
||||
|
||||
Status HandleCompare(HloInstruction* compare) override;
|
||||
|
||||
Status HandleConcatenate(HloInstruction* concatenate) override;
|
||||
|
||||
Status HandleConstant(HloInstruction* constant) override;
|
||||
@ -2213,6 +2216,49 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
||||
HloInstruction* lhs;
|
||||
HloInstruction* rhs;
|
||||
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
|
||||
|
||||
auto replace_with_pred_broadcast = [&](bool value) {
|
||||
return ReplaceWithNewInstruction(
|
||||
compare,
|
||||
HloInstruction::CreateBroadcast(
|
||||
compare->shape(),
|
||||
computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))),
|
||||
{}));
|
||||
};
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
return replace_with_pred_broadcast(false);
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kGt &&
|
||||
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
|
||||
return replace_with_pred_broadcast(false);
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kGe &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
return replace_with_pred_broadcast(true);
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kLe &&
|
||||
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
|
||||
return replace_with_pred_broadcast(true);
|
||||
}
|
||||
if (lhs == rhs &&
|
||||
primitive_util::IsIntegralType(lhs->shape().element_type())) {
|
||||
switch (compare->comparison_direction()) {
|
||||
case ComparisonDirection::kGt:
|
||||
case ComparisonDirection::kLt:
|
||||
case ComparisonDirection::kNe:
|
||||
return replace_with_pred_broadcast(false);
|
||||
case ComparisonDirection::kEq:
|
||||
case ComparisonDirection::kGe:
|
||||
case ComparisonDirection::kLe:
|
||||
return replace_with_pred_broadcast(true);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A conversion to the same element type as the operand is a nop and can be
|
||||
// removed. A conversion of a constant can be simplified by making a new
|
||||
// constant.
|
||||
|
@ -5372,21 +5372,54 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
|
||||
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
||||
}
|
||||
|
||||
// This test exposes a real bug: It tries to read an out-of-bounds array index
|
||||
// from within ComposePermutations(). TODO(b/132330723): Fix this.
|
||||
TEST_F(AlgebraicSimplifierTest,
|
||||
DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||
// No optimization opportunity here because the transpose does not reorder the
|
||||
// contracting dims.
|
||||
const char* kModuleStr = R"(
|
||||
param = f32[2,5,1,3] parameter(0)
|
||||
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
|
||||
reshape = f32[5,6] reshape(transpose)
|
||||
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
ROOT dot = f32[5,4] dot(reshape, constant),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}}
|
||||
HloModule m
|
||||
test {
|
||||
param = f32[2,5,1,3] parameter(0)
|
||||
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
|
||||
reshape = f32[5,6] reshape(transpose)
|
||||
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
ROOT dot = f32[5,4] dot(reshape, constant),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareIota) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = s32[] constant(0)
|
||||
iota = s32[128] iota(), iota_dimension=0
|
||||
broad = s32[128] broadcast(zero), dimensions={}
|
||||
ROOT compare = pred[128] compare(iota, broad), direction=LT
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
param = s32[123] parameter(0)
|
||||
ROOT compare = pred[123] compare(param, param), direction=GE
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -221,8 +221,8 @@ void AllocationTracker::AddAllocationOrIncrementRefCount(
|
||||
auto it = allocation_map.find(device_memory.opaque());
|
||||
if (it == allocation_map.end()) {
|
||||
allocation_map[device_memory.opaque()] = {
|
||||
OwningDeviceMemory(device_memory, device_ordinal,
|
||||
backend_->memory_allocator()),
|
||||
se::OwningDeviceMemory(device_memory, device_ordinal,
|
||||
backend_->memory_allocator()),
|
||||
/*ref_count=*/1};
|
||||
} else {
|
||||
it->second.ref_count++;
|
||||
|
@ -77,7 +77,7 @@ class AllocationTracker {
|
||||
// Data structure encapsulating single memory allocation on the device.
|
||||
struct Allocation {
|
||||
// The pointer to this allocation.
|
||||
OwningDeviceMemory device_memory;
|
||||
se::OwningDeviceMemory device_memory;
|
||||
|
||||
// This is the number of times this memory allocation is referred to by
|
||||
// registered data handles.
|
||||
|
@ -107,44 +107,90 @@ absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
absl::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
|
||||
HloInstruction* instruction) {
|
||||
CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
|
||||
HloComputation* computation = instruction->parent();
|
||||
auto caller_instructions = call_graph_->GetComputationCallers(computation);
|
||||
if (caller_instructions.size() == 1) {
|
||||
auto caller_instruction = caller_instructions[0];
|
||||
if (caller_instruction->opcode() == HloOpcode::kConditional) {
|
||||
return caller_instruction;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples(
|
||||
HloInstruction* instruction) {
|
||||
if (instruction->opcode() == HloOpcode::kTuple) {
|
||||
return {instruction};
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kDomain) {
|
||||
return GetAllTuples(instruction->operands()[0]);
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
auto maybe_while = WhileFromBodyParameter(instruction);
|
||||
if (!maybe_while) {
|
||||
return {};
|
||||
}
|
||||
auto while_instr = *maybe_while;
|
||||
auto init_tuples = GetAllTuples(while_instr->while_init());
|
||||
auto body_tuples =
|
||||
GetAllTuples(while_instr->while_body()->root_instruction());
|
||||
if (init_tuples.empty() || body_tuples.empty()) {
|
||||
return {};
|
||||
}
|
||||
init_tuples.insert(init_tuples.end(), body_tuples.begin(),
|
||||
body_tuples.end());
|
||||
return init_tuples;
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kGetTupleElement) {
|
||||
std::vector<HloInstruction*> result_tuples;
|
||||
for (auto tuple : GetAllTuples(instruction->operands()[0])) {
|
||||
auto tmp_tuples =
|
||||
GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
|
||||
if (tmp_tuples.empty()) {
|
||||
return {};
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kTuple:
|
||||
return {instruction};
|
||||
case HloOpcode::kDomain:
|
||||
return GetAllTuples(instruction->operands()[0]);
|
||||
case HloOpcode::kParameter: {
|
||||
auto maybe_while = WhileFromBodyParameter(instruction);
|
||||
if (maybe_while) {
|
||||
auto while_instr = *maybe_while;
|
||||
auto init_tuples = GetAllTuples(while_instr->while_init());
|
||||
auto body_tuples =
|
||||
GetAllTuples(while_instr->while_body()->root_instruction());
|
||||
if (init_tuples.empty() || body_tuples.empty()) {
|
||||
return {};
|
||||
}
|
||||
init_tuples.insert(init_tuples.end(), body_tuples.begin(),
|
||||
body_tuples.end());
|
||||
return init_tuples;
|
||||
}
|
||||
result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
|
||||
tmp_tuples.end());
|
||||
auto maybe_conditional = ConditionalFromBodyParameter(instruction);
|
||||
if (maybe_conditional) {
|
||||
auto cond_instr = *maybe_conditional;
|
||||
std::vector<HloInstruction*> tuples;
|
||||
for (int64 i = 0; i < cond_instr->branch_computations().size(); ++i) {
|
||||
if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
|
||||
instruction) {
|
||||
// If the same computation is used for more than one branch of the
|
||||
// conditional, we collect the arguments that flow to the
|
||||
// computation from all branches.
|
||||
auto branch_tuples =
|
||||
GetAllTuples(cond_instr->mutable_operand(i + 1));
|
||||
if (branch_tuples.empty()) {
|
||||
return {};
|
||||
}
|
||||
tuples.insert(tuples.end(), branch_tuples.begin(),
|
||||
branch_tuples.end());
|
||||
}
|
||||
}
|
||||
return tuples;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
return result_tuples;
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
std::vector<HloInstruction*> result_tuples;
|
||||
for (auto tuple : GetAllTuples(instruction->operands()[0])) {
|
||||
auto tmp_tuples =
|
||||
GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
|
||||
if (tmp_tuples.empty()) {
|
||||
return {};
|
||||
}
|
||||
result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
|
||||
tmp_tuples.end());
|
||||
}
|
||||
return result_tuples;
|
||||
}
|
||||
case HloOpcode::kConditional: {
|
||||
std::vector<HloInstruction*> result_tuples;
|
||||
for (HloComputation* body : instruction->branch_computations()) {
|
||||
if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
|
||||
return {};
|
||||
}
|
||||
result_tuples.push_back(body->root_instruction());
|
||||
}
|
||||
return result_tuples;
|
||||
}
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool ArCrsCombiner::TupleElementsComputeSameValue(
|
||||
|
@ -119,6 +119,12 @@ class ArCrsCombiner : public HloModulePass {
|
||||
absl::optional<HloInstruction*> WhileFromBodyParameter(
|
||||
HloInstruction* instruction);
|
||||
|
||||
// If the passed instruction is a parameter in one of the branch computations,
|
||||
// and the branch body is only called by a single instruction, return the
|
||||
// conditional instruction.
|
||||
absl::optional<HloInstruction*> ConditionalFromBodyParameter(
|
||||
HloInstruction* instruction);
|
||||
|
||||
// Returns a vector of tuple instructions.
|
||||
// If all instructions that flow to "instruction" are tuples, return them.
|
||||
// Otherwise, return an empty vector.
|
||||
|
@ -1173,5 +1173,47 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
EXPECT_FALSE(changed);
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, SameValueTestConditional) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
branch_true {
|
||||
pt = (f32[2,4], f32[2,4]) parameter(0)
|
||||
gte.0 = f32[2,4] get-tuple-element(pt), index=0
|
||||
gte.1 = f32[2,4] get-tuple-element(pt), index=1
|
||||
ROOT tuple.t = (f32[2,4], f32[2,4]) tuple(gte.1, gte.0)
|
||||
}
|
||||
|
||||
branch_false {
|
||||
pf = (f32[2,4], f32[2,4]) parameter(0)
|
||||
gte.0 = f32[2,4] get-tuple-element(pf), index=0
|
||||
gte.1 = f32[2,4] get-tuple-element(pf), index=1
|
||||
add = f32[2,4] add(gte.1, gte.1)
|
||||
ROOT tuple.f = (f32[2,4], f32[2,4]) tuple(gte.0, add)
|
||||
}
|
||||
|
||||
ENTRY Parameters1.v4 {
|
||||
constant = pred[] constant(true)
|
||||
p = f32[2,4] parameter(0)
|
||||
tuple = (f32[2,4], f32[2,4]) tuple(p, p)
|
||||
ROOT conditional = (f32[2,4], f32[2,4]) conditional(constant, tuple, tuple), true_computation=branch_true, false_computation=branch_false
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto cond = module->entry_computation()->root_instruction();
|
||||
|
||||
auto branch_true = cond->branch_computation(0)->root_instruction();
|
||||
auto t0 = branch_true->mutable_operand(0);
|
||||
auto t1 = branch_true->mutable_operand(1);
|
||||
EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(t0, t1));
|
||||
|
||||
auto branch_false = cond->branch_computation(1)->root_instruction();
|
||||
auto f0 = branch_false->mutable_operand(0);
|
||||
auto f1 = branch_false->mutable_operand(1);
|
||||
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(f0, f1));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -134,7 +134,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler,
    }
  }
  // Create a memory allocator for the valid stream executors.
  memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>(
  memory_allocator_ = absl::make_unique<se::StreamExecutorMemoryAllocator>(
      platform, stream_executors);
  CHECK(!stream_executors_.empty())
      << "Service found no devices for backend " << platform_->Name() << '.';

@ -27,7 +27,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -35,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace Eigen {
struct ThreadPoolDevice;
@ -88,7 +88,7 @@ class Backend {
  // Accessors for the various objects.
  se::Platform* platform() const { return platform_; }
  Compiler* compiler() const { return compiler_; }
  DeviceMemoryAllocator* memory_allocator() const {
  se::DeviceMemoryAllocator* memory_allocator() const {
    return memory_allocator_.get();
  }
  TransferManager* transfer_manager() const { return transfer_manager_; }
@ -179,7 +179,7 @@ class Backend {
      stream_pools_ GUARDED_BY(mu_);

  // The default memory allocator to use.
  std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
  std::unique_ptr<se::StreamExecutorMemoryAllocator> memory_allocator_;

  // For the CPU backend, an Eigen threadpool device for use by Eigen code.
  struct IntraOpThreadPool;

@ -75,8 +75,10 @@ class AotCompilationOptions {

  // Optional allocator that may be used for allocating temp space on the device
  // during compilation.
  DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
  void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
  se::DeviceMemoryAllocator* device_allocator() const {
    return device_allocator_;
  }
  void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
    device_allocator_ = device_allocator;
  }

@ -98,7 +100,7 @@ class AotCompilationOptions {
  AotCompilationOptions();

 private:
  DeviceMemoryAllocator* device_allocator_ = nullptr;
  se::DeviceMemoryAllocator* device_allocator_ = nullptr;
  DebugOptions debug_options_;
  absl::optional<DeviceAssignment> static_device_assignment_;
};
@ -147,14 +149,14 @@ class Compiler {
  // allocated should be deallocated before this function returns.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;
      se::DeviceMemoryAllocator* device_allocator) = 0;

  // Optimizes a HLO module group, a set of module which runs concurrently on
  // multiple devices potentially communicating data between the modules.
  virtual Status RunHloPassesOnModuleGroup(
      HloModuleGroup* module_group,
      absl::Span<se::StreamExecutor* const> executors,
      DeviceMemoryAllocator* device_allocator) = 0;
      se::DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. No HLO passes are
@ -168,7 +170,7 @@ class Compiler {
  // device_allocator is optional; see RunHloPasses.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;
      se::DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules.
@ -176,7 +178,7 @@ class Compiler {
  RunBackendOnModuleGroup(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;
      se::DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
@ -189,7 +191,7 @@ class Compiler {
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;
      se::DeviceMemoryAllocator* device_allocator) = 0;

  // Returns the backend configurations that the backend will consider for the
  // given HLO. Returns no configurations if the backend does not support

@ -182,7 +182,6 @@ cc_library(
    deps = [
        ":compiler_functor",
        ":cpu_runtime",
        ":custom_call_target_registry",
        ":disassembler",
        ":orc_jit_memory_mapper",
        ":runtime_fp16",
@ -203,6 +202,7 @@ cc_library(
        "@llvm//:orc_jit",
        "@llvm//:support",
        "@llvm//:target", # fixdeps: keep
        "//tensorflow/compiler/xla/service:custom_call_target_registry",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
@ -245,7 +245,6 @@ cc_library(
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_execution_profile",
@ -255,6 +254,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/host:host_stream",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
@ -946,17 +946,6 @@ cc_library(
    ],
)

cc_library(
    name = "custom_call_target_registry",
    srcs = [
        "custom_call_target_registry.cc",
    ],
    hdrs = [
        "custom_call_target_registry.h",
    ],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "orc_jit_memory_mapper",
    srcs = ["orc_jit_memory_mapper.cc"],

@ -537,7 +537,7 @@ Status CreateHloProfilingArtifacts(

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    DeviceMemoryAllocator* /*device_allocator*/) {
    se::DeviceMemoryAllocator* /*device_allocator*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
@ -597,7 +597,7 @@ struct OrcJITPostCompilationHook {

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* /*device_allocator*/) {
    se::DeviceMemoryAllocator* /*device_allocator*/) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

@ -133,11 +133,11 @@ class CpuCompiler : public LLVMCompiler {

  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      DeviceMemoryAllocator* device_allocator) override;
      se::DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      DeviceMemoryAllocator* device_allocator) override;
      se::DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

@ -73,13 +73,13 @@ CpuExecutable::CpuExecutable(
}

StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
                   std::vector<OwningDeviceMemory>>>
                   std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable(
    DeviceMemoryAllocator* memory_allocator, int device_ordinal,
    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
    absl::Span<const ShapedBuffer* const> arguments) {
  std::vector<se::DeviceMemoryBase> unowning_buffers(
      assignment_->Allocations().size());
  std::vector<OwningDeviceMemory> owning_buffers(
  std::vector<se::OwningDeviceMemory> owning_buffers(
      assignment_->Allocations().size());
  VLOG(3) << "Allocating " << assignment_->Allocations().size()
          << " allocations for module " << module().name();
@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction(

StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
    const ServiceExecutableRunOptions* run_options,
    absl::Span<OwningDeviceMemory> buffers) {
    absl::Span<se::OwningDeviceMemory> buffers) {
  se::Stream* stream = run_options->stream();
  ScopedShapedBuffer result_buffer(
      /*on_host_shape=*/result_shape(),
@ -216,7 +216,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
  const HloInputOutputAliasConfig& input_output_alias =
      module().input_output_alias_config();

  // Move OwningDeviceMemory values which contain the array(s) of the result
  // Move se::OwningDeviceMemory values which contain the array(s) of the result
  // into the respective location in ScopedShapedBuffer which is returned to the
  // caller.
  TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
@ -235,7 +235,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
            const BufferAllocation::Slice slice,
            this->assignment_->GetUniqueSlice(src, buffer_source->index()));
        const BufferAllocation::Index buffer_index = slice.index();
        OwningDeviceMemory& buffer = buffers[buffer_index];
        se::OwningDeviceMemory& buffer = buffers[buffer_index];
        if (!slice.allocation()->is_entry_computation_parameter()) {
          // If the buffer coming out of the result is from a parameter, the
          // owning buffer will be null, and that means the caller aliased some
@ -297,8 +297,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
  auto* host_stream = dynamic_cast<se::host::HostStream*>(
      run_options->stream()->implementation());
  se::Stream* stream = run_options->stream();
  DeviceMemoryAllocator* memory_allocator = run_options->allocator();
  std::vector<OwningDeviceMemory> owning_buffers;
  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
  std::vector<se::OwningDeviceMemory> owning_buffers;
  std::vector<se::DeviceMemoryBase> unowning_buffers;
  TF_ASSIGN_OR_RETURN(
      std::tie(unowning_buffers, owning_buffers),
@ -326,7 +326,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
    CpuExecutable* executable;
    ServiceExecutableRunOptions run_options;
    std::vector<se::DeviceMemoryBase> unowning_buffers;
    std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
    std::shared_ptr<std::vector<se::OwningDeviceMemory>> buffers;
    HloExecutionProfile* hlo_execution_profile;

    void operator()() {
@ -338,7 +338,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
  };
  host_stream->EnqueueTask(
      AsyncRunTask{this, *run_options, std::move(unowning_buffers),
                   std::make_shared<std::vector<OwningDeviceMemory>>(
                   std::make_shared<std::vector<se::OwningDeviceMemory>>(
                       std::move(owning_buffers)),
                   hlo_execution_profile});

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace cpu {
@ -111,8 +111,9 @@ class CpuExecutable : public Executable {
  // storage and the live-out buffer into which the computation writes it
  // result.
  StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
                     std::vector<OwningDeviceMemory>>>
  CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
                     std::vector<se::OwningDeviceMemory>>>
  CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
                    int device_ordinal,
                    absl::Span<const ShapedBuffer* const> arguments);

  // Calls the generated function performing the computation with the given
@ -126,7 +127,7 @@ class CpuExecutable : public Executable {
  // The addresses are set according to buffer assignment.
  StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<OwningDeviceMemory> buffers);
      absl::Span<se::OwningDeviceMemory> buffers);

  // Returns the points-to set of the root instruction of the entry
  // computation. Uses points-to analysis from buffer assignment.

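One detail in the `cpu_executable.cc` hunks above is easy to miss behind the `se::` renames: `ExecuteAsyncOnStreamImpl` moves the freshly allocated `se::OwningDeviceMemory` vector into a `std::shared_ptr` that the enqueued `AsyncRunTask` captures, so the allocations outlive the caller's stack frame and are released only after the host task has actually run. Below is a minimal, self-contained sketch of that ownership pattern, with stand-in types (`OwningBuffer` and a plain `std::queue` in place of `se::OwningDeviceMemory` and `HostStream::EnqueueTask`); it is an illustration of the idea, not code from this commit.

```cpp
#include <cstdio>
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

struct OwningBuffer {            // stand-in for se::OwningDeviceMemory
  explicit OwningBuffer(int id) : id(id) {}
  ~OwningBuffer() { std::printf("freeing buffer %d\n", id); }
  int id;
};

int main() {
  std::queue<std::function<void()>> host_stream;  // stand-in for HostStream::EnqueueTask

  {
    std::vector<std::unique_ptr<OwningBuffer>> owning_buffers;
    owning_buffers.push_back(std::make_unique<OwningBuffer>(0));
    owning_buffers.push_back(std::make_unique<OwningBuffer>(1));

    // Move ownership into a shared_ptr captured by the task, so the buffers
    // stay alive after this scope ends.
    auto shared = std::make_shared<std::vector<std::unique_ptr<OwningBuffer>>>(
        std::move(owning_buffers));
    host_stream.push([shared] { std::printf("running compute function\n"); });
  }  // the caller's frame is gone, but nothing has been freed yet

  while (!host_stream.empty()) {
    host_stream.front()();  // "freeing buffer ..." prints only after this runs
    host_stream.pop();      // dropping the task releases the last shared_ptr
  }
}
```

Running the sketch prints "running compute function" before either "freeing buffer" line, which is exactly the guarantee the shared_ptr hand-off buys in the real executable.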
@ -1,74 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_

// This file is depended on by kernels that have to build for mobile devices.
// For this reason, we avoid relying on TensorFlow and instead only use the
// standard C++ library.

#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>

namespace xla {
namespace cpu {

// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
// targets; so when using the CPU JIT, CustomCall targets need to be registered
// here with the symbol name used in the CustomCall.
//
// The XLA AOT compiler links using a standard offline linker; so when compiling
// in AOT mode, you *also* need to make sure the name of the callee (presumably
// implemented in C++) matches up with the symbolic name used in the CustomCall.
//
// We maintain the registry in both the JIT and the AOT cases for simplicity,
// but we only use it when running in JIT mode.
class CustomCallTargetRegistry {
 public:
  static CustomCallTargetRegistry* Global();

  void Register(const std::string& symbol, void* address);
  void* Lookup(const std::string& symbol) const;

 private:
  std::unordered_map<std::string, void*> registered_symbols_;
  mutable std::mutex mu_;
};

class RegisterCustomCallTarget {
 public:
  explicit RegisterCustomCallTarget(const std::string& name, void* address) {
    CustomCallTargetRegistry::Global()->Register(name, address);
  }
};

#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b

#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \
  static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT(    \
      custom_call_target_register, counter)(symbol,                           \
                                            reinterpret_cast<void*>(address))

#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
  REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__)

#define REGISTER_CUSTOM_CALL_TARGET(function) \
  REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_

@ -119,13 +119,9 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");

  // This implements the same polynomial approximation as implemented in Eigen3.

  // This implements the same polynomial approximation as implemented in Cephes.
  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1.0);

  const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
  const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
  const llvm::APFloat one = GetIeeeF32(1);

  const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
  const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
@ -138,39 +134,79 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
  const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
  const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);

  llvm::Value* input_clamped =
      vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
  llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
  llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
  llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
  llvm::Value* x = vsl.Sub(input_clamped, tmp);
  x = vsl.Sub(x, z);
  z = vsl.Mul(x, x);
  // To compute e^input, we re-express it as
  //
  //   e^input = e^(a + b)
  //           = e^(a + n log(2))
  //           = e^a * 2^n.
  //
  // We choose n = floor(a * log(2) + 0.5), restricting the value of `a` to
  // (-0.5, 0.5). We then use a polynomial to compute e^a.

  llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
  y = vsl.MulAdd(y, x, cephes_exp_p2);
  y = vsl.MulAdd(y, x, cephes_exp_p3);
  y = vsl.MulAdd(y, x, cephes_exp_p4);
  y = vsl.MulAdd(y, x, cephes_exp_p5);
  y = vsl.MulAdd(y, z, x);
  y = vsl.Add(one, y);
  // Restrict input to a small range, including some values that evaluate to
  // +/- inf. Our computations below aren't particularly sensitive to the exact
  // choices here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(F32_EPSILON) = -103.279....
  //
  input = vsl.Clamp(input, GetIeeeF32(-104), GetIeeeF32(88.8));

  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
  // time so drop down to IRBuilder for this bit.
  llvm::Value* vector_constant_0x7f =
      b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
  llvm::Value* vector_constant_23 =
      b->CreateVectorSplat(vector_width, b->getInt32(23));
  llvm::Type* i32_vector_type =
      llvm::VectorType::get(b->getInt32Ty(), vector_width);
  // fx is clamped so we don't have to worry about it being out of range for
  // i32.
  llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type);
  emm0 = b->CreateAdd(emm0, vector_constant_0x7f);
  emm0 = b->CreateShl(emm0, vector_constant_23);
  llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type());
  llvm::Value* x = input;
  llvm::Value* n = vsl.Floor(vsl.MulAdd(input, cephes_LOG2EF, half));

  return vsl.Max(vsl.Mul(y, emm0_f32), input);
  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow. It's not clear why; our
  // hypothesis is that the integer operations on the exponent `n` have nonlocal
  // effects on the pipeline.
  //
  // The approach we use instead is to clamp n to [-126, 127] so 2^n doesn't
  // over/underflow. This causes `a` to be outside the range (-0.5, 0.5), which
  // means that our polynomial for e^a will give a less-accurate result. In
  // practice this seems to work well enough; it passes our exhaustive tests,
  // breaking only one result, and by one ulp (we return exp(88.7228394) =
  // max-float but we should return inf).
  n = vsl.Clamp(n, GetIeeeF32(-126), GetIeeeF32(127));

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C1, n));
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C2, n));
  llvm::Value* z = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
  z = vsl.MulAdd(z, x, cephes_exp_p2);
  z = vsl.MulAdd(z, x, cephes_exp_p3);
  z = vsl.MulAdd(z, x, cephes_exp_p4);
  z = vsl.MulAdd(z, x, cephes_exp_p5);
  z = vsl.MulAdd(z, vsl.Mul(x, x), x);
  z = vsl.Add(one, z);

  // Convert n to an i32. This is safe because we clamped it above.
  llvm::Value* n_i32 =
      b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width));

  // Create 2^n as an fp32. This works because -126 <= n <= 127 means that n is
  // within the bounds for an fp32 exponent.
  auto splat_i32 = [&](int32 v) {
    return b->CreateVectorSplat(vector_width, b->getInt32(v));
  };
  const int32 kF32SignificandBits = 23;
  llvm::Value* exp_bias = splat_i32(0x7f);
  llvm::Value* pow2 =
      b->CreateBitCast(b->CreateShl(b->CreateAdd(n_i32, exp_bias),
                                    splat_i32(kF32SignificandBits)),
                       vsl.vector_type());

  // Return z * 2^n.
  return vsl.Mul(z, pow2);
}

llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,

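The rewritten `GenerateVF32Exp` above documents its own scheme: rewrite e^x as e^a * 2^n with n = floor(x * log2(e) + 0.5), clamp n into the representable exponent range, evaluate e^a with a Cephes-style polynomial, and build 2^n by writing n + 127 into the exponent bits of a float. The scalar C++ sketch below mirrors that sequence for reference only; the coefficients `cephes_exp_p0`..`p3` and `cephes_exp_C2` sit outside the visible hunk, so the standard single-precision Cephes values are assumed, and the function name is illustrative.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <initializer_list>

float Vf32ExpReference(float x) {
  // Clamp the input like the vector code does; values outside this range
  // saturate to 0 or +inf anyway.
  x = std::fmax(-104.0f, std::fmin(x, 88.8f));

  const float kLog2e = 1.44269504088896341f;  // cephes_LOG2EF
  const float kC1 = 0.693359375f;             // high part of ln(2)
  const float kC2 = -2.12194440E-4f;          // low part of ln(2) (assumed)

  // n = floor(x * log2(e) + 0.5), kept inside the normal fp32 exponent range.
  float n = std::floor(x * kLog2e + 0.5f);
  n = std::fmax(-126.0f, std::fmin(n, 127.0f));

  // a = x - n * ln(2), subtracted in two pieces for accuracy.
  float a = x - n * kC1;
  a -= n * kC2;

  // Horner evaluation of the Cephes polynomial for e^a on roughly (-0.5, 0.5).
  float z = 1.9875691500E-4f;
  z = z * a + 1.3981999507E-3f;
  z = z * a + 8.3334519073E-3f;
  z = z * a + 4.1665795894E-2f;
  z = z * a + 1.6666665459E-1f;
  z = z * a + 5.0000001201E-1f;
  z = z * a * a + a;  // z*a^2 + a
  z += 1.0f;

  // 2^n by bit manipulation: (n + 127) placed in the exponent field.
  int32_t biased = static_cast<int32_t>(n) + 0x7f;
  uint32_t bits = static_cast<uint32_t>(biased) << 23;
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));

  return z * pow2n;
}

int main() {
  for (float x : {-10.0f, -1.0f, 0.0f, 1.0f, 10.0f, 80.0f}) {
    std::printf("x=%g approx=%g libm=%g\n", x, Vf32ExpReference(x), std::exp(x));
  }
}
```

The printed comparison against `std::exp` makes it easy to see the approximation tracking libm over the clamped input range.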
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"

#include <stdint.h>

#include <algorithm>
#include <list>
#include <utility>
@ -28,7 +29,6 @@ limitations under the License.
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

@ -146,16 +147,18 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
    // On Mac OS X, 'name' may have a leading underscore prefix, even though the
    // registered name may not.
    std::string stripped_name(name.begin() + 1, name.end());
    func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name);
    func_addr =
        xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
  } else {
    func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
    func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
  }

  if (func_addr == nullptr) {
    LOG(ERROR)
        << "Unable to resolve runtime symbol: `" << name
        << "'. Hint: if the symbol a custom call target, make sure you've "
           "registered it with the JIT using REGISTER_CUSTOM_CALL_TARGET.";
           "registered it with the JIT using "
           "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
    return nullptr;
  }
  llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
@ -209,14 +212,15 @@ llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
namespace {
// Register some known symbols with the CustomCallTargetRegistry.
bool RegisterKnownJITSymbols() {
  CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
  xla::CustomCallTargetRegistry* registry =
      xla::CustomCallTargetRegistry::Global();

#define REGISTER_CPU_RUNTIME_SYMBOL(base_name)                               \
  do {                                                                       \
    auto* function_address =                                                 \
        reinterpret_cast<void*>(__xla_cpu_runtime_##base_name);              \
    registry->Register(xla::cpu::runtime::k##base_name##SymbolName,          \
                       function_address);                                    \
                       function_address, "Host");                            \
    CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
             "__xla_cpu_runtime_" #base_name);                               \
  } while (false)
@ -247,8 +251,10 @@ bool RegisterKnownJITSymbols() {
  REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
  REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);

  registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
  registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
  registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
                     "Host");
  registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
                     "Host");

#undef REGISTER_CPU_RUNTIME_SYMBOL

@ -256,11 +262,12 @@ bool RegisterKnownJITSymbols() {
// Unfortunately the double versions are overloaded on some systems, e.g.
// Mac so we need an explicit cast. This requires passing the function signature
// for that case.
#define REGISTER_LIBM_SYMBOL(name, double_sig)                                \
  do {                                                                        \
    registry->Register(#name "f", reinterpret_cast<void*>(name##f));          \
    registry->Register(                                                       \
        #name, reinterpret_cast<void*>(static_cast<double_sig>(name)));       \
#define REGISTER_LIBM_SYMBOL(name, double_sig)                                \
  do {                                                                        \
    registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host");  \
    registry->Register(#name,                                                 \
                       reinterpret_cast<void*>(static_cast<double_sig>(name)),\
                       "Host");                                               \
  } while (false)

  REGISTER_LIBM_SYMBOL(acos, double (*)(double));
@ -318,8 +325,9 @@ bool RegisterKnownJITSymbols() {
#ifdef __APPLE__
  REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
  registry->Register("__sincosf_stret",
                     reinterpret_cast<void*>(__sincosf_stret));
  registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
                     reinterpret_cast<void*>(__sincosf_stret), "Host");
  registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
                     "Host");
#else
  REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
@ -332,19 +340,19 @@ bool RegisterKnownJITSymbols() {

#undef REGISTER_LIBM_SYMBOL

  registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
  registry->Register("memmove", reinterpret_cast<void*>(memmove));
  registry->Register("memset", reinterpret_cast<void*>(memset));
  registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
  registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
  registry->Register("memset", reinterpret_cast<void*>(memset), "Host");

#ifdef __APPLE__
  registry->Register("__bzero", reinterpret_cast<void*>(bzero));
  registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
  registry->Register("memset_pattern16",
                     reinterpret_cast<void*>(memset_pattern16));
                     reinterpret_cast<void*>(memset_pattern16), "Host");
#endif

#ifdef MEMORY_SANITIZER
  registry->Register("__msan_unpoison",
                     reinterpret_cast<void*>(__msan_unpoison));
                     reinterpret_cast<void*>(__msan_unpoison), "Host");
#endif

  return true;

@ -107,13 +107,19 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  CHECK(!low.isNaN());
  CHECK(!high.isNaN());
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);

  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
  CHECK(scalar_type_->isFloatingPointTy());
  return llvm_ir::EmitFloatMin(
      llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
      GetConstantFloat(type, high), b_);

  llvm::Value* low_value = GetConstantFloat(type, low);
  llvm::Value* high_value = GetConstantFloat(type, high);
  a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
  a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
  return a;
}

llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
@ -100,8 +100,10 @@ class VectorSupportLibrary {

  llvm::Value* Floor(llvm::Value* a);

  // Precondition: Neither `low` nor `high` is nan.
  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);

  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

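The new `VectorSupportLibrary::Clamp` above drops `EmitFloatMax`/`EmitFloatMin` in favour of two selects on unordered comparisons (`FCmpUGE`, `FCmpULE`). Assuming the point of the unordered predicates is that a NaN input compares true and therefore flows through unclamped (the bounds themselves are now `CHECK`ed to be non-NaN), a scalar model of the behaviour looks like the sketch below; the function name is illustrative.

```cpp
#include <cmath>
#include <cstdio>

float ClampLikeVsl(float a, float low, float high) {
  // "unordered >=": also true when a is NaN, so NaN keeps flowing through.
  a = (std::isnan(a) || a >= low) ? a : low;
  // "unordered <=": same NaN-passthrough behaviour on the upper bound.
  a = (std::isnan(a) || a <= high) ? a : high;
  return a;
}

int main() {
  std::printf("%g\n", ClampLikeVsl(-200.0f, -104.0f, 88.8f));  // -> -104
  std::printf("%g\n", ClampLikeVsl(3.0f, -104.0f, 88.8f));     // -> 3
  std::printf("%g\n", ClampLikeVsl(NAN, -104.0f, 88.8f));      // -> nan
}
```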
@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

namespace xla {
namespace cpu {

CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
  static auto* registry = new CustomCallTargetRegistry;
@ -24,16 +23,17 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
}

void CustomCallTargetRegistry::Register(const std::string& symbol,
                                        void* address) {
                                        void* address,
                                        const std::string& platform) {
  std::lock_guard<std::mutex> lock(mu_);
  registered_symbols_[symbol] = address;
  registered_symbols_[std::make_pair(symbol, platform)] = address;
}

void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
                                       const std::string& platform) const {
  std::lock_guard<std::mutex> lock(mu_);
  auto it = registered_symbols_.find(symbol);
  auto it = registered_symbols_.find(std::make_pair(symbol, platform));
  return it == registered_symbols_.end() ? nullptr : it->second;
}

}  // namespace cpu
}  // namespace xla
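The registry that replaces the deleted `xla::cpu` version keys each entry by a (symbol, platform) pair, which is what the extra "Host" argument threaded through `simple_orc_jit.cc` earlier in this diff feeds into `Register` and `Lookup`. Below is a self-contained sketch of that shape, not the real class; an ordered `std::map` stands in for the real container, since `std::pair` has no default `std::hash`, and `MyCustomCall` is a hypothetical target.

```cpp
#include <cstdio>
#include <map>
#include <mutex>
#include <string>
#include <utility>

class PairKeyedRegistry {
 public:
  void Register(const std::string& symbol, void* address,
                const std::string& platform) {
    std::lock_guard<std::mutex> lock(mu_);
    registered_symbols_[std::make_pair(symbol, platform)] = address;
  }

  void* Lookup(const std::string& symbol, const std::string& platform) const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = registered_symbols_.find(std::make_pair(symbol, platform));
    return it == registered_symbols_.end() ? nullptr : it->second;
  }

 private:
  // Keyed by (symbol, platform), mirroring the Register/Lookup changes above.
  std::map<std::pair<std::string, std::string>, void*> registered_symbols_;
  mutable std::mutex mu_;
};

static void MyCustomCall() {}  // illustrative custom-call target

int main() {
  PairKeyedRegistry registry;
  registry.Register("MyCustomCall", reinterpret_cast<void*>(&MyCustomCall),
                    "Host");
  std::printf("Host lookup: %p\n", registry.Lookup("MyCustomCall", "Host"));
  std::printf("CUDA lookup: %p\n", registry.Lookup("MyCustomCall", "CUDA"));
}
```

The second lookup returns a null pointer because nothing was registered for the hypothetical "CUDA" platform, which is exactly the separation the widened key buys.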