Performance improvements to speed up invocation of XLA code by making
canonicalization and signature generation faster.

Added benchmark for XlaCompilationCache::BuildSignature to measure time
taken to build a signature for the cache.

"Base" is this CL with only the benchmark added in
xla_compilation_cache_test.cc; "New" is this whole CL.

Run on desktop machine (40 X 2793 MHz CPUs); 2019-09-17T08:30:04.125894664-07:00
CPU: Intel Ivybridge with HyperThreading (20 cores) dL1:32KB dL2:256KB dL3:25MB
Benchmark                                      Base (ns)    New (ns) Improvement
----------------------------------------------------------------------------
BM_BuildSignature/0                                  226          87    +61.5%
BM_BuildSignature/1                                  337         171    +49.3%
BM_BuildSignature/2                                  504         259    +48.6%
BM_BuildSignature/5                                 1008         592    +41.3%
BM_BuildSignature/10                                1751        1238    +29.3%

RELNOTES: n/a
PiperOrigin-RevId: 276289188
Change-Id: Ia47343203f6ac587a921a92f86c2428dd04db2a7
This commit is contained in:
Jeffrey A. Dean 2019-10-23 09:24:14 -07:00 committed by TensorFlower Gardener
parent c87a16e17a
commit ec030f72f3
8 changed files with 152 additions and 33 deletions

View File

@ -313,6 +313,7 @@ cc_library(
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -124,6 +124,7 @@ XlaCompilationCache::BuildSignature(
absl::Span<const XlaCompiler::Argument> args) {
Signature signature;
signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
for (const XlaCompiler::Argument& arg : args) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
@ -131,7 +132,8 @@ XlaCompilationCache::BuildSignature(
break;
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kResource:
signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes());
signature.arg_shapes.emplace_back(arg.type,
arg.DimensionSizesAsInlinedVector());
break;
default:
return errors::InvalidArgument(

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -97,11 +98,12 @@ class XlaCompilationCache : public ResourceBase {
// List of Tensor types & shapes for compile-time constant arguments to the
// compilation, ordered by argument number.
std::vector<std::pair<DataType, std::vector<int64>>> arg_shapes;
absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4>
arg_shapes;
// List of Tensor values for compile-time constant arguments to the
// compilation, ordered by argument number. Tensors must be in host memory.
std::vector<Tensor> arg_values;
absl::InlinedVector<Tensor, 4> arg_values;
bool operator==(const Signature& other) const;

View File

@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
@ -50,5 +52,29 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
}
}
// Benchmarks XlaCompilationCache::BuildSignature.
//
// `iters` is the number of BuildSignature calls to perform. `n_args` is both
// the number of attributes set on the function and the number of arguments
// passed to BuildSignature.
static void BM_BuildSignature(int iters, int n_args) {
  // Function with attributes "T0" .. "T<n_args-1>", all of type DT_FLOAT.
  NameAttrList fn;
  fn.set_name("afunction");
  for (int i = 0; i < n_args; i++) {
    (*fn.mutable_attr())[absl::StrCat("T", i)].set_type(DT_FLOAT);
  }

  // Every third argument (i % 3 == 0) is a constant; the others are
  // parameters. All of them are DT_INT32 with shape {4, 0}.
  std::vector<XlaCompiler::Argument> args(n_args);
  for (int i = 0; i < n_args; i++) {
    args[i].kind = (((i % 3) == 0) ? XlaCompiler::Argument::kConstant
                                   : XlaCompiler::Argument::kParameter);
    args[i].type = DT_INT32;
    args[i].shape = TensorShape({4, 0});
    args[i].constant_value = Tensor(DT_INT32, {4, 0});
  }

  // Perform exactly `iters` calls. A pre-decrement loop (`--iters > 0`) would
  // perform one call fewer and none at all when iters == 1.
  for (int iter = 0; iter < iters; ++iter) {
    xla::StatusOr<XlaCompilationCache::Signature> s =
        XlaCompilationCache::BuildSignature(fn, args);
    CHECK(s.ok());
    // Move the result out so that taking ownership of the signature is part
    // of the measured work.
    XlaCompilationCache::Signature sig = std::move(s.ValueOrDie());
  }
}
BENCHMARK(BM_BuildSignature)->Arg(0)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
} // namespace tensorflow

View File

@ -494,6 +494,16 @@ std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
}
}
// Returns the dimension sizes of this argument's shape as an
// absl::InlinedVector<int64, 4>, whether `shape` currently holds a
// TensorShape or an xla::Shape.
absl::InlinedVector<int64, 4>
XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
  if (!absl::holds_alternative<TensorShape>(shape)) {
    // xla::Shape alternative: copy its dimensions into an inlined vector.
    const auto xla_dims = absl::get<xla::Shape>(shape).dimensions();
    return absl::InlinedVector<int64, 4>(xla_dims.begin(), xla_dims.end());
  }
  // TensorShape alternative: return its dim_sizes() result directly.
  return absl::get<TensorShape>(shape).dim_sizes();
}
string XlaCompiler::Argument::ShapeHumanString() const {
if (absl::holds_alternative<TensorShape>(shape)) {
return absl::get<TensorShape>(shape).DebugString();

View File

@ -179,6 +179,7 @@ class XlaCompiler {
// Returns the dimension sizes for either TensorShape or xla::Shape.
std::vector<int64> DimensionSizes() const;
absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
// Returns the human-readable string for either TensorShape or xla::Shape.
string ShapeHumanString() const;

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include <ctype.h>
#include <map>
#include <unordered_map>
#include <utility>
@ -509,7 +511,7 @@ string Print(const AttrValue& attr_value) {
return attr_value.func().name();
}
std::vector<string> entries;
for (auto p : attr_value.func().attr()) {
for (const auto& p : attr_value.func().attr()) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
}
std::sort(entries.begin(), entries.end());
@ -825,7 +827,7 @@ namespace {
// and adds an unset attr to the map.
std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
std::map<string, AttrValue> set_attrs;
for (auto pair : fdef.attr()) {
for (const auto& pair : fdef.attr()) {
if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
set_attrs[pair.first] = pair.second;
}
@ -841,7 +843,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
if (f1_attrs.size() != f2_attrs.size()) return false;
for (auto iter1 : f1_attrs) {
for (const auto& iter1 : f1_attrs) {
auto iter2 = f2_attrs.find(iter1.first);
if (iter2 == f2_attrs.end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
@ -910,55 +912,132 @@ string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
}
}
namespace {
class AttrKeyAndValue {
public:
enum ValueRepresentationOp {
kRaw,
kCEscape,
};
AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
ValueRepresentationOp value_op = kRaw)
: key_name_(key_name),
key_suffix_(key_suffix),
value_op_(value_op),
value_(std::move(value)) {}
bool operator<(const AttrKeyAndValue& b) const {
if (key_name_ != b.key_name_) {
return key_name_ < b.key_name_;
} else if (key_suffix_ != b.key_suffix_) {
return key_suffix_ < b.key_suffix_;
} else {
return value_ < b.value_;
}
}
void AppendTo(bool first, string* s) const {
absl::string_view v;
bool add_escaped = false;
if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
// Use CEscape call below
add_escaped = true;
} else {
// Add raw value contents directly
v = value_;
}
if (key_suffix_ >= 0) {
strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
} else {
strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
}
if (add_escaped) {
strings::StrAppend(s, absl::CEscape(value_));
}
}
private:
static bool NeedsEscaping(const string& s) {
for (auto c : s) {
if (!isalnum(c) && (c != ' ')) {
return true;
}
}
return false;
}
absl::string_view key_name_;
int key_suffix_; // -1 if missing
ValueRepresentationOp value_op_;
string value_;
};
} // namespace
string Canonicalize(const string& funcname, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options) {
std::vector<string> entries;
entries.reserve(attrs.size() + static_cast<int>(options.target.empty()) +
absl::InlinedVector<AttrKeyAndValue, 8> entries;
entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
options.input_devices.size());
for (auto p : attrs) {
for (const auto& p : attrs) {
if (p.first != kExecutorAttr) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
}
}
if (!options.target.empty()) {
entries.push_back(
strings::StrCat("_target", "=", absl::CEscape(options.target)));
entries.push_back(AttrKeyAndValue("_target", -1, options.target,
AttrKeyAndValue::kCEscape));
}
for (int i = 0; i < options.input_devices.size(); ++i) {
entries.push_back(strings::StrCat("_input_dev", i, "=",
absl::CEscape(options.input_devices[i])));
entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
AttrKeyAndValue::kCEscape));
}
for (int i = 0; i < options.output_devices.size(); ++i) {
entries.push_back(strings::StrCat(
"_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
entries.push_back(AttrKeyAndValue("_output_dev", i,
options.output_devices[i],
AttrKeyAndValue::kCEscape));
}
for (const auto& iter : options.input_resource_dtypes_and_shapes) {
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
DataTypeString(iter.second.dtype)));
entries.push_back(
strings::StrCat("_input_resource_shape", iter.first, "=",
absl::CEscape(iter.second.shape.DebugString())));
entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
iter.second.shape.DebugString(),
AttrKeyAndValue::kCEscape));
}
if (options.lib_def) {
entries.push_back(strings::StrCat(
"_lib_def", "=", reinterpret_cast<uintptr_t>(options.lib_def)));
entries.push_back(AttrKeyAndValue(
"_lib_def", -1,
absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
}
if (!options.state_handle.empty()) {
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
AttrKeyAndValue("_state_handle", -1, options.state_handle));
}
string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
if (!executor_type.empty()) {
entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
}
string config_proto_serialized;
options.config_proto.SerializeToString(&config_proto_serialized);
if (!config_proto_serialized.empty()) {
entries.push_back(strings::StrCat("_config_proto", "=",
absl::CEscape(config_proto_serialized)));
if (options.config_proto.ByteSize() > 0) {
string config_proto_serialized;
options.config_proto.SerializeToString(&config_proto_serialized);
entries.push_back(AttrKeyAndValue("_config_proto", -1,
config_proto_serialized,
AttrKeyAndValue::kCEscape));
}
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", absl::StrJoin(entries, ","), "]");
string result = strings::StrCat(funcname, "[");
bool first = true;
for (const auto& entry : entries) {
entry.AppendTo(first, &result);
first = false;
}
result += "]";
return result;
}
// Canonicalizes `funcname` and `attrs` using default-constructed
// InstantiateOptions.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  // Allocated once on first use, reused by every later call, and never
  // deleted.
  static const FunctionLibraryRuntime::InstantiateOptions& empty_options =
      *new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, empty_options);
}
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,

View File

@ -805,9 +805,7 @@ class FunctionLibraryRuntime {
// address spaces.
string Canonicalize(const string& funcname, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options);
inline string Canonicalize(const string& funcname, AttrSlice attrs) {
return Canonicalize(funcname, attrs, {});
}
string Canonicalize(const string& funcname, AttrSlice attrs);
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;