Performance improvements to speed up invocation of XLA code, by making
canonicalization and signature generation faster.

Added a benchmark for XlaCompilationCache::BuildSignature to measure the time taken to build a signature for the cache.

Base is this CL with just the changes to add the benchmark in xla_compilation_cache_test.cc; New is this whole CL.

Run on desktop machine (40 X 2793 MHz CPUs); 2019-09-17T08:30:04.125894664-07:00
CPU: Intel Ivybridge with HyperThreading (20 cores) dL1:32KB dL2:256KB dL3:25MB

Benchmark                Base (ns)   New (ns)   Improvement
----------------------------------------------------------------------------
BM_BuildSignature/0            226         87        +61.5%
BM_BuildSignature/1            337        171        +49.3%
BM_BuildSignature/2            504        259        +48.6%
BM_BuildSignature/5           1008        592        +41.3%
BM_BuildSignature/10          1751       1238        +29.3%

RELNOTES: n/a
PiperOrigin-RevId: 276289188
Change-Id: Ia47343203f6ac587a921a92f86c2428dd04db2a7
This commit is contained in:
parent
c87a16e17a
commit
ec030f72f3
@ -313,6 +313,7 @@ cc_library(
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -124,6 +124,7 @@ XlaCompilationCache::BuildSignature(
|
||||
absl::Span<const XlaCompiler::Argument> args) {
|
||||
Signature signature;
|
||||
signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
|
||||
for (const XlaCompiler::Argument& arg : args) {
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
@ -131,7 +132,8 @@ XlaCompilationCache::BuildSignature(
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
case XlaCompiler::Argument::kResource:
|
||||
signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes());
|
||||
signature.arg_shapes.emplace_back(arg.type,
|
||||
arg.DimensionSizesAsInlinedVector());
|
||||
break;
|
||||
default:
|
||||
return errors::InvalidArgument(
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -97,11 +98,12 @@ class XlaCompilationCache : public ResourceBase {
|
||||
|
||||
// List of Tensor types & shapes for compile-time constant arguments to the
|
||||
// compilation, ordered by argument number.
|
||||
std::vector<std::pair<DataType, std::vector<int64>>> arg_shapes;
|
||||
absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4>
|
||||
arg_shapes;
|
||||
|
||||
// List of Tensor values for compile-time constant arguments to the
|
||||
// compilation, ordered by argument number. Tensors must be in host memory.
|
||||
std::vector<Tensor> arg_values;
|
||||
absl::InlinedVector<Tensor, 4> arg_values;
|
||||
|
||||
bool operator==(const Signature& other) const;
|
||||
|
||||
|
@ -14,8 +14,10 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -50,5 +52,29 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_BuildSignature(int iters, int n_args) {
|
||||
NameAttrList fn;
|
||||
fn.set_name("afunction");
|
||||
for (int i = 0; i < n_args; i++) {
|
||||
(*fn.mutable_attr())[absl::StrCat("T", i)].set_type(DT_FLOAT);
|
||||
}
|
||||
std::vector<XlaCompiler::Argument> args(n_args);
|
||||
for (int i = 0; i < n_args; i++) {
|
||||
args[i].kind = (((i % 3) == 0) ? XlaCompiler::Argument::kConstant
|
||||
: XlaCompiler::Argument::kParameter);
|
||||
args[i].type = DT_INT32;
|
||||
args[i].shape = TensorShape({4, 0});
|
||||
args[i].constant_value = Tensor(DT_INT32, {4, 0});
|
||||
}
|
||||
|
||||
while (--iters > 0) {
|
||||
xla::StatusOr<XlaCompilationCache::Signature> s =
|
||||
XlaCompilationCache::BuildSignature(fn, args);
|
||||
CHECK(s.ok());
|
||||
XlaCompilationCache::Signature sig = std::move(s.ValueOrDie());
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_BuildSignature)->Arg(0)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -494,6 +494,16 @@ std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the dimension sizes of `shape` as an absl::InlinedVector<int64, 4>,
// whether `shape` currently holds a TensorShape or an xla::Shape.
absl::InlinedVector<int64, 4>
XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
  // TensorShape alternative: hand back its dim_sizes() result directly.
  if (absl::holds_alternative<TensorShape>(shape)) {
    return absl::get<TensorShape>(shape).dim_sizes();
  }
  // Otherwise read the xla::Shape alternative and copy its dimensions into
  // the inlined vector.
  const auto& xla_dims = absl::get<xla::Shape>(shape).dimensions();
  return absl::InlinedVector<int64, 4>(xla_dims.begin(), xla_dims.end());
}
|
||||
|
||||
string XlaCompiler::Argument::ShapeHumanString() const {
|
||||
if (absl::holds_alternative<TensorShape>(shape)) {
|
||||
return absl::get<TensorShape>(shape).DebugString();
|
||||
|
@ -179,6 +179,7 @@ class XlaCompiler {
|
||||
|
||||
// Returns the dimension sizes for either TensorShape or xla::Shape.
|
||||
std::vector<int64> DimensionSizes() const;
|
||||
absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
|
||||
|
||||
// Returns the human-readable string for either TensorShape or xla::Shape.
|
||||
string ShapeHumanString() const;
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
|
||||
#include <ctype.h>
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
@ -509,7 +511,7 @@ string Print(const AttrValue& attr_value) {
|
||||
return attr_value.func().name();
|
||||
}
|
||||
std::vector<string> entries;
|
||||
for (auto p : attr_value.func().attr()) {
|
||||
for (const auto& p : attr_value.func().attr()) {
|
||||
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
||||
}
|
||||
std::sort(entries.begin(), entries.end());
|
||||
@ -825,7 +827,7 @@ namespace {
|
||||
// and adds an unset attr to the map.
|
||||
std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
|
||||
std::map<string, AttrValue> set_attrs;
|
||||
for (auto pair : fdef.attr()) {
|
||||
for (const auto& pair : fdef.attr()) {
|
||||
if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
|
||||
set_attrs[pair.first] = pair.second;
|
||||
}
|
||||
@ -841,7 +843,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
|
||||
std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
|
||||
std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
|
||||
if (f1_attrs.size() != f2_attrs.size()) return false;
|
||||
for (auto iter1 : f1_attrs) {
|
||||
for (const auto& iter1 : f1_attrs) {
|
||||
auto iter2 = f2_attrs.find(iter1.first);
|
||||
if (iter2 == f2_attrs.end()) return false;
|
||||
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
|
||||
@ -910,55 +912,132 @@ string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
class AttrKeyAndValue {
|
||||
public:
|
||||
enum ValueRepresentationOp {
|
||||
kRaw,
|
||||
kCEscape,
|
||||
};
|
||||
AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
|
||||
ValueRepresentationOp value_op = kRaw)
|
||||
: key_name_(key_name),
|
||||
key_suffix_(key_suffix),
|
||||
value_op_(value_op),
|
||||
value_(std::move(value)) {}
|
||||
|
||||
bool operator<(const AttrKeyAndValue& b) const {
|
||||
if (key_name_ != b.key_name_) {
|
||||
return key_name_ < b.key_name_;
|
||||
} else if (key_suffix_ != b.key_suffix_) {
|
||||
return key_suffix_ < b.key_suffix_;
|
||||
} else {
|
||||
return value_ < b.value_;
|
||||
}
|
||||
}
|
||||
|
||||
void AppendTo(bool first, string* s) const {
|
||||
absl::string_view v;
|
||||
bool add_escaped = false;
|
||||
if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
|
||||
// Use CEscape call below
|
||||
add_escaped = true;
|
||||
} else {
|
||||
// Add raw value contents directly
|
||||
v = value_;
|
||||
}
|
||||
if (key_suffix_ >= 0) {
|
||||
strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
|
||||
} else {
|
||||
strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
|
||||
}
|
||||
if (add_escaped) {
|
||||
strings::StrAppend(s, absl::CEscape(value_));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static bool NeedsEscaping(const string& s) {
|
||||
for (auto c : s) {
|
||||
if (!isalnum(c) && (c != ' ')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::string_view key_name_;
|
||||
int key_suffix_; // -1 if missing
|
||||
ValueRepresentationOp value_op_;
|
||||
string value_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
string Canonicalize(const string& funcname, AttrSlice attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& options) {
|
||||
std::vector<string> entries;
|
||||
entries.reserve(attrs.size() + static_cast<int>(options.target.empty()) +
|
||||
absl::InlinedVector<AttrKeyAndValue, 8> entries;
|
||||
entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
|
||||
options.input_devices.size());
|
||||
for (auto p : attrs) {
|
||||
for (const auto& p : attrs) {
|
||||
if (p.first != kExecutorAttr) {
|
||||
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
||||
entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
|
||||
}
|
||||
}
|
||||
if (!options.target.empty()) {
|
||||
entries.push_back(
|
||||
strings::StrCat("_target", "=", absl::CEscape(options.target)));
|
||||
entries.push_back(AttrKeyAndValue("_target", -1, options.target,
|
||||
AttrKeyAndValue::kCEscape));
|
||||
}
|
||||
for (int i = 0; i < options.input_devices.size(); ++i) {
|
||||
entries.push_back(strings::StrCat("_input_dev", i, "=",
|
||||
absl::CEscape(options.input_devices[i])));
|
||||
entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
|
||||
AttrKeyAndValue::kCEscape));
|
||||
}
|
||||
for (int i = 0; i < options.output_devices.size(); ++i) {
|
||||
entries.push_back(strings::StrCat(
|
||||
"_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
|
||||
entries.push_back(AttrKeyAndValue("_output_dev", i,
|
||||
options.output_devices[i],
|
||||
AttrKeyAndValue::kCEscape));
|
||||
}
|
||||
for (const auto& iter : options.input_resource_dtypes_and_shapes) {
|
||||
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
|
||||
entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
|
||||
DataTypeString(iter.second.dtype)));
|
||||
entries.push_back(
|
||||
strings::StrCat("_input_resource_shape", iter.first, "=",
|
||||
absl::CEscape(iter.second.shape.DebugString())));
|
||||
entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
|
||||
iter.second.shape.DebugString(),
|
||||
AttrKeyAndValue::kCEscape));
|
||||
}
|
||||
if (options.lib_def) {
|
||||
entries.push_back(strings::StrCat(
|
||||
"_lib_def", "=", reinterpret_cast<uintptr_t>(options.lib_def)));
|
||||
entries.push_back(AttrKeyAndValue(
|
||||
"_lib_def", -1,
|
||||
absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
|
||||
}
|
||||
if (!options.state_handle.empty()) {
|
||||
entries.push_back(
|
||||
strings::StrCat("_state_handle", "=", options.state_handle));
|
||||
AttrKeyAndValue("_state_handle", -1, options.state_handle));
|
||||
}
|
||||
string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
|
||||
if (!executor_type.empty()) {
|
||||
entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
|
||||
entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
|
||||
}
|
||||
string config_proto_serialized;
|
||||
options.config_proto.SerializeToString(&config_proto_serialized);
|
||||
if (!config_proto_serialized.empty()) {
|
||||
entries.push_back(strings::StrCat("_config_proto", "=",
|
||||
absl::CEscape(config_proto_serialized)));
|
||||
if (options.config_proto.ByteSize() > 0) {
|
||||
string config_proto_serialized;
|
||||
options.config_proto.SerializeToString(&config_proto_serialized);
|
||||
entries.push_back(AttrKeyAndValue("_config_proto", -1,
|
||||
config_proto_serialized,
|
||||
AttrKeyAndValue::kCEscape));
|
||||
}
|
||||
std::sort(entries.begin(), entries.end());
|
||||
return strings::StrCat(funcname, "[", absl::StrJoin(entries, ","), "]");
|
||||
string result = strings::StrCat(funcname, "[");
|
||||
bool first = true;
|
||||
for (const auto& entry : entries) {
|
||||
entry.AppendTo(first, &result);
|
||||
first = false;
|
||||
}
|
||||
result += "]";
|
||||
return result;
|
||||
}
|
||||
|
||||
// Two-argument overload: canonicalizes `funcname` and `attrs` using a
// default-constructed FunctionLibraryRuntime::InstantiateOptions.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  // Allocated once on first use and never deleted, so every call reuses the
  // same options object instead of constructing a new one.
  static const FunctionLibraryRuntime::InstantiateOptions& empty_options =
      *new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, empty_options);
}
|
||||
|
||||
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
|
||||
|
@ -805,9 +805,7 @@ class FunctionLibraryRuntime {
|
||||
// address spaces.
|
||||
string Canonicalize(const string& funcname, AttrSlice attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& options);
|
||||
inline string Canonicalize(const string& funcname, AttrSlice attrs) {
|
||||
return Canonicalize(funcname, attrs, {});
|
||||
}
|
||||
string Canonicalize(const string& funcname, AttrSlice attrs);
|
||||
|
||||
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
|
||||
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user