Initial open-source release of XLA: Accelerated Linear Algebra.

XLA is a compiler-based linear algebra execution engine that targets CPUs, GPUs and custom accelerators.

XLA is still experimental; we are releasing it early to get the community involved.
Change: 143990941
Commit 1e67c90e2c (parent 7ad7e4dfae)
Authored by Peter Hawkins on 2017-01-09 12:04:37 -08:00; committed by TensorFlower Gardener
656 changed files with 138,481 additions and 2 deletions

configure
@@ -112,6 +112,26 @@ else
sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
fi
## Enable XLA.
while [ "$TF_ENABLE_XLA" == "" ]; do
read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
[Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
"" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
if [ "$TF_ENABLE_XLA" == "1" ]; then
# Update Bazel build configuration.
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
fi
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"

tensorflow/BUILD
@@ -95,6 +95,26 @@ filegroup(
"//tensorflow/c:all_files",
"//tensorflow/cc:all_files",
"//tensorflow/cc/saved_model:all_files",
"//tensorflow/compiler/aot:all_files",
"//tensorflow/compiler/aot/tests:all_files",
"//tensorflow/compiler/jit:all_files",
"//tensorflow/compiler/jit/graphcycles:all_files",
"//tensorflow/compiler/jit/legacy_flags:all_files",
"//tensorflow/compiler/tests:all_files",
"//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/kernels:all_files",
"//tensorflow/compiler/xla:all_files",
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
"//tensorflow/compiler/xla/legacy_flags:all_files",
"//tensorflow/compiler/xla/port:all_files",
"//tensorflow/compiler/xla/service:all_files",
"//tensorflow/compiler/xla/service/cpu:all_files",
"//tensorflow/compiler/xla/service/gpu:all_files",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files",
"//tensorflow/compiler/xla/service/llvm_ir:all_files",
"//tensorflow/compiler/xla/tests:all_files",
"//tensorflow/compiler/xla/tools:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",

tensorflow/compiler/aot/BUILD
@@ -0,0 +1,218 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//visibility:private"],
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Optional runtime utilities for use by code generated by tfcompile.
cc_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
],
)
cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Don't depend on this directly; this is only used for the benchmark test
# generated by tf_library.
cc_library(
name = "tf_library_test_main",
testonly = 1,
visibility = ["//visibility:public"],
deps = ["//tensorflow/core:test_main"],
)
xla_proto_library(
name = "tfcompile_proto",
srcs = ["tfcompile.proto"],
deps = [
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfcompile_lib",
srcs = [
"codegen.cc",
"compile.cc",
"flags.cc",
"tfcompile_util.cc",
],
hdrs = [
"codegen.h",
"compile.h",
"flags.h",
"tfcompile_util.h",
],
deps = [
":runtime", # needed by codegen to print aligned_buffer_bytes
":tfcompile_proto",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_test(
name = "codegen_test",
srcs = ["codegen_test.cc"],
data = ["codegen_test_h.golden"],
deps = [
":tfcompile_lib",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_test(
name = "tfcompile_util_test",
srcs = ["tfcompile_util_test.cc"],
deps = [
":tfcompile_lib",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_binary(
name = "tfcompile",
visibility = ["//visibility:public"],
deps = [":tfcompile_main"],
)
cc_library(
name = "tfcompile_main",
srcs = ["tfcompile_main.cc"],
visibility = ["//visibility:public"],
deps = [
":tfcompile_lib",
":tfcompile_proto",
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.
# A simple test of tf_library from a text protobuf, mostly to enable the
# benchmark_test.
tf_library(
name = "test_graph_tfadd",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
tags = ["manual"],
)
# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(
name = "benchmark",
srcs = ["benchmark.cc"],
hdrs = ["benchmark.h"],
visibility = ["//visibility:public"],
deps = [
# The purpose of the benchmark library is to support building an aot
# binary with minimal dependencies, to demonstrate small binary sizes.
#
# KEEP THE DEPENDENCIES MINIMAL.
"//tensorflow/core:framework_lite",
],
)
cc_library(
name = "benchmark_extra_android",
tags = [
"manual",
"notap",
],
visibility = ["//visibility:public"],
)
cc_test(
name = "benchmark_test",
srcs = ["benchmark_test.cc"],
tags = ["manual"],
deps = [
":benchmark",
":test_graph_tfadd",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
test_suite(
name = "all_tests",
tags = ["manual"],
tests = [
":benchmark_test",
":test_graph_tfadd_test",
"//tensorflow/compiler/aot/tests:all_tests",
],
)
exports_files([
"benchmark_main.template", # used by tf_library(...,gen_benchmark=True)
"test.cc", # used by tf_library(...,gen_test=True)
])
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

tensorflow/compiler/aot/benchmark.cc
@@ -0,0 +1,138 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The purpose of the benchmark library is to support building an aot binary
// with minimal dependencies, to demonstrate small binary sizes.
//
// KEEP THE DEPENDENCIES MINIMAL.
#include "tensorflow/compiler/aot/benchmark.h"
#include <sys/time.h>
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tfcompile {
namespace benchmark {
// Returns current wall time in micros.
//
// TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use
// the implementation without pulling in all of the Env dependencies.
static double NowMicros() {
struct timeval tv;
gettimeofday(&tv, NULL);
return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec;
}
void DumpStatsToStdout(const Stats& stats) {
// Compute stats.
std::vector<int64> sorted_us(stats.per_iter_us);
std::sort(sorted_us.begin(), sorted_us.end());
const size_t count_us = sorted_us.size();
double sum_us = 0;
size_t count_us_trimmed = 0;
double sum_us_trimmed = 0;
size_t count_us_best = 0;
double sum_us_best = 0;
static constexpr float trim_ratio = 0.25;
static constexpr float best_ratio = 0.1;
const size_t count_trimmed = count_us * trim_ratio;
const size_t count_best = count_us * best_ratio;
for (size_t i = 0; i < sorted_us.size(); ++i) {
const int64 us = sorted_us[i];
sum_us += us;
if (i >= count_trimmed && i < count_us - count_trimmed) {
sum_us_trimmed += us;
++count_us_trimmed;
}
if (i < count_best) {
sum_us_best += us;
++count_us_best;
}
}
// Prepare nicely-formatted data.
const int kBufSize = 1000;
char buf[kBufSize];
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
const string label_trimmed(buf);
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
const string label_best(buf);
std::vector<std::pair<string, double>> groups = {
{"Best:", sorted_us.front()},
{"Worst:", sorted_us.back()},
{"Median:", sorted_us[count_us / 2]},
{"Mean:", sum_us / count_us},
{label_trimmed, sum_us_trimmed / count_us_trimmed},
{label_best, sum_us_best / count_us_best},
};
int max_label_size = 0;
double max_us = 0;
for (const auto& g : groups) {
if (g.first.size() > max_label_size) {
max_label_size = g.first.size();
}
if (g.second > max_us) {
max_us = g.second;
}
}
int max_digits = 1;
while (max_us >= 10.0) {
max_us /= 10.0;
++max_digits;
}
// Dump stats out.
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
stats.total_us);
for (const auto& g : groups) {
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
g.second);
}
}
void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
// If neither max_micros nor max_iters is set, stop at kDefaultMicros.
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
? Options::kDefaultMicros
: options.max_micros;
printf("Running benchmark for %lld us\n", max_us);
const int64 start_us = NowMicros();
int64 iters = 0;
while (true) {
const int64 iter_start_us = NowMicros();
fn();
const int64 end_us = NowMicros();
// Collect stats and decide whether to stop.
stats->per_iter_us.push_back(end_us - iter_start_us);
const int64 total_us = end_us - start_us;
++iters;
if ((max_us > 0 && total_us >= max_us) ||
(options.max_iters > 0 && iters >= options.max_iters)) {
stats->total_us = total_us;
break;
}
}
}
} // namespace benchmark
} // namespace tfcompile
} // namespace tensorflow
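
As a concrete illustration of the trimming logic above (an editorial sketch, not a file in this change): with ten per-iteration timings and trim_ratio = 0.25, the two fastest and two slowest samples are dropped and the remaining six are averaged, so a single pathological iteration barely moves the reported mean.

// Standalone sketch of the 25%-trimmed mean computed by DumpStatsToStdout.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> us = {120, 100, 90, 400, 110, 95, 105, 98, 102, 3000};
  std::sort(us.begin(), us.end());

  const size_t count = us.size();
  const size_t trim = static_cast<size_t>(count * 0.25f);  // 2 samples per end
  double sum = 0;
  size_t kept = 0;
  for (size_t i = trim; i < count - trim; ++i) {
    sum += us[i];
    ++kept;
  }
  // Averages {98, 100, 102, 105, 110, 120}; the 90/95 us samples and the
  // 400/3000 us outliers are discarded.
  std::printf("25%% trimmed mean: %.1f us over %zu of %zu samples\n",
              sum / kept, kept, count);
  return 0;
}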

tensorflow/compiler/aot/benchmark.h
@@ -0,0 +1,70 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains benchmark functions used with the code-generated benchmarks that can
// be used to test a model on android. See also code generation rules in
// tfcompile.bzl.
//
// This is separate from the built-in micro-benchmarks, because we want to:
// 1. show a binary with minimal dependencies, to show a close-to-lower-bound
// binary size.
// 2. compile on Android.
#ifndef TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
#define TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tfcompile {
namespace benchmark {
// Options specifies options for benchmarks of functions generated by tfcompile.
struct Options {
// kDefaultMicros specifies the default time to run the benchmark, and is used
// if neither max_iters nor max_micros is set.
static const int64 kDefaultMicros = 3000000;
int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0.
int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0.
};
// Stats holds statistics collected during benchmarking.
struct Stats {
std::vector<int64> per_iter_us; // Per-iteration deltas in us.
int64 total_us; // Total time in us.
Stats() : total_us(0) { per_iter_us.reserve(5000); }
};
// DumpStatsToStdout prints the given stats to stdout in a multi-line,
// human-friendly form.
void DumpStatsToStdout(const Stats& stats);
// BenchmarkFn is the signature of the function generated by tfcompile.
typedef std::function<void()> BenchmarkFn;
// Benchmark runs a benchmark of the function `fn`, collecting stats in `stats`.
// Use `options` to configure benchmarking options.
void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats);
} // namespace benchmark
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
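
For reference, a minimal standalone driver for this header might look like the following sketch (not part of this change); the lambda stands in for whatever compiled computation is being timed:

#include <cmath>

#include "tensorflow/compiler/aot/benchmark.h"

int main() {
  namespace bm = tensorflow::tfcompile::benchmark;

  bm::Options options;
  options.max_micros = 500000;  // Stop after ~0.5s instead of the 3s default.

  // Any void() callable can serve as the BenchmarkFn; this one just burns CPU.
  volatile double sink = 0;
  bm::Stats stats;
  bm::Benchmark(options,
                [&sink] {
                  for (int i = 1; i <= 1000; ++i) {
                    sink = sink + std::sqrt(static_cast<double>(i));
                  }
                },
                &stats);

  bm::DumpStatsToStdout(stats);
  return 0;
}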

tensorflow/compiler/aot/benchmark_main.template
@@ -0,0 +1,51 @@
// Generated by the tf_library build rule. DO NOT EDIT!
//
// This file contains the main function and logic for benchmarking code
// generated by tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be
// rewritten to real values before this file can be compiled.
//
// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
//
// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
// generates a cc_binary rule for you.
// These macros must be defined before eigen files are included.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
// clang-format off
#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
// clang-format on
#include "tensorflow/compiler/aot/benchmark.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Macros that expand to tokens based on the entry point name.
// clang-format off
#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
// clang-format on
namespace tensorflow {
namespace tfcompile {
int Main(int argc, char** argv) {
Eigen::ThreadPool pool(1 /* num_threads */);
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
CPP_CLASS computation;
computation.set_thread_pool(&device);
benchmark::Options options;
benchmark::Stats stats;
benchmark::Benchmark(options, [&] { computation.Run(); }, &stats);
benchmark::DumpStatsToStdout(stats);
return 0;
}
} // namespace tfcompile
} // namespace tensorflow
int main(int argc, char** argv) {
return tensorflow::tfcompile::Main(argc, argv);
}

tensorflow/compiler/aot/benchmark_test.cc
@@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/benchmark.h"
#include "tensorflow/compiler/aot/test_graph_tfadd.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace benchmark {
namespace {
// There isn't much we can verify in a stable fashion, so we just run the
// benchmark with max_iters, and ensure we end up with that many iter stats.
TEST(Benchmark, Benchmark) {
AddComp add;
Options options;
options.max_iters = 1;
Stats stats1;
Benchmark(options, [&] { add.Run(); }, &stats1);
EXPECT_EQ(stats1.per_iter_us.size(), 1);
options.max_iters = 5;
Stats stats5;
Benchmark(options, [&] { add.Run(); }, &stats5);
EXPECT_EQ(stats5.per_iter_us.size(), 5);
}
} // namespace
} // namespace benchmark
} // namespace tfcompile
} // namespace tensorflow

tensorflow/compiler/aot/codegen.cc
@@ -0,0 +1,579 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/codegen.h"
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
namespace {
// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
switch (type) {
case xla::PRED:
*str = "bool";
break;
case xla::S8:
*str = "tensorflow::int8";
break;
case xla::S16:
*str = "tensorflow::int16";
break;
case xla::S32:
*str = "tensorflow::int32";
break;
case xla::S64:
*str = "tensorflow::int64";
break;
case xla::U8:
*str = "tensorflow::uint8";
break;
case xla::U16:
*str = "tensorflow::uint16";
break;
case xla::U32:
*str = "tensorflow::uint32";
break;
case xla::U64:
*str = "tensorflow::uint64";
break;
case xla::F32:
*str = "float";
break;
case xla::F64:
*str = "double";
break;
default:
return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
" has no equivalent in C++");
}
return Status::OK();
}
// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`.
size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
if (sizes[i] != -1) {
total += sizes[i];
}
}
return total;
}
// Fills in arg_sizes with the byte size of each positional arg.
Status ComputeArgSizes(const CompileResult& compile_result,
std::vector<int64>* arg_sizes) {
const xla::ProgramShape& ps = compile_result.program_shape;
for (int i = 0; i < ps.parameters_size(); ++i) {
if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) {
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
// always last, and must be represented as an opaque type.
const xla::PrimitiveType type = ps.parameters(i).element_type();
if (type != xla::OPAQUE) {
return errors::InvalidArgument(
"expected final context arg to be opaque, but got type: ",
xla::PrimitiveType_Name(type), ", from program shape: ",
xla::ShapeUtil::HumanString(ps));
}
arg_sizes->push_back(-1);
} else {
arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
ps.parameters(i), compile_result.pointer_size));
}
}
return Status::OK();
}
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
// are used to generate methods for args and results.
Status AddRewritesForShape(int i, const xla::Shape& shape,
std::vector<std::pair<string, string>>* rewrites) {
string type;
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
std::vector<string> dim_vars;
string dim_sizes, indices;
if (xla::ShapeUtil::Rank(shape) == 0 ||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
dim_sizes = "[1]";
indices = "[0]";
} else {
for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
dim_vars.push_back(strings::StrCat("size_t dim", dim));
dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
indices += strings::StrCat("[dim", dim, "]");
}
}
rewrites->push_back({"{{I}}", strings::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type});
rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
return Status::OK();
}
// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
// for the name. Note that the rewriting strategy is roughly O(N*M), where N is
// the size of the code and M is the number of rewrites. It's fine for now
// since N and M are pretty small.
//
// TODO(toddw): If this becomes a problem, we should be able to change the
// algorithm to O(N) by using a state machine, e.g. regexps or a real
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) {
str_util::ReplaceAllPairs(&code, rewrites);
str_util::ReplaceAll(&code, "{{NAME}}", name);
return code;
}
// Generate methods for args (inputs).
Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
*methods += R"(
void** args() { return args_; }
const void *const *args() const { return args_; }
)";
size_t num_args = ps.parameters_size();
if (compile_result.has_context_arg) {
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
// always last, and is set in the class constructor.
num_args--;
}
if (config.feed_size() != num_args) {
return errors::InvalidArgument("mismatch between feed_size(",
config.feed_size(), ") and num_args(",
num_args, ")");
}
for (int i = 0; i < num_args; ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
args_[{{I}}] = data;
}
{{TYPE}}* arg{{NAME}}_data() {
return static_cast<{{TYPE}}*>(args_[{{I}}]);
}
{{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
args_[{{I}}])){{INDICES}};
}
const {{TYPE}}* arg{{NAME}}_data() const {
return static_cast<const {{TYPE}}*>(args_[{{I}}]);
}
const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
args_[{{I}}])){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
if (!config.feed(i).name().empty()) {
*methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
}
}
return Status::OK();
}
// Generate methods for results (outputs).
Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// Non-tuple (i.e. single-result) case.
if (config.fetch_size() != 1) {
return errors::InvalidArgument(
"non-tuple result implies 1 fetch, but got ", config.fetch_size(),
" fetches");
}
*methods += R"(
void** results() { return temps_ + kResultIndex; }
const void *const *results() const { return temps_ + kResultIndex; }
)";
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites));
const string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(temps_[kResultIndex]);
}
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
temps_[kResultIndex])){{INDICES}};
}
const {{TYPE}}* result{{NAME}}_data() const {
return static_cast<const {{TYPE}}*>(temps_[kResultIndex]);
}
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
temps_[kResultIndex])){{INDICES}};
}
)";
*methods += RewriteWithName("0", code, rewrites);
if (!config.fetch(0).name().empty()) {
*methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites);
}
return Status::OK();
}
// Tuple (i.e. multi-result) case.
if (config.fetch_size() != ps.result().tuple_shapes_size()) {
return errors::InvalidArgument("mismatch between fetch_size(",
config.fetch_size(), ") and tuple_size(",
ps.result().tuple_shapes_size(), ")");
}
*methods += R"(
void** results() {
return static_cast<void**>(temps_[kResultIndex]);
}
const void *const *results() const {
return static_cast<const void *const *>(temps_[kResultIndex]);
}
)";
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(
static_cast<void**>(temps_[kResultIndex])[{{I}}]);
}
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
}
const {{TYPE}}* result{{NAME}}_data() const {
return static_cast<{{TYPE}}*>(
static_cast<void**>(temps_[kResultIndex])[{{I}}]);
}
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
if (!config.fetch(i).name().empty()) {
*methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
}
}
return Status::OK();
}
} // namespace
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
const CompileResult& compile_result, string* header) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index >= temp_sizes.size()) {
return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,",
temp_sizes.size(), ")");
}
// Compute sizes and generate methods.
std::vector<int64> arg_sizes;
TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
const xla::ProgramShape& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
const size_t arg_bytes_aligned =
runtime::aligned_buffer_bytes(iarg.data(), iarg.size());
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
const size_t temp_bytes_aligned =
runtime::aligned_buffer_bytes(itemp.data(), itemp.size());
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());
// Create rewrite strings for the optional context arg.
string context_include;
string context_set_arg, context_set_thread_pool, context_member_var;
string run_result = "true";
string error_msg = "tensorflow::string()";
if (compile_result.has_context_arg) {
// NOTE: Extra spaces and newlines are used to ensure nice formatting.
context_include =
"#include "
"\"tensorflow/compiler/tf2xla/"
"xla_local_runtime_context.h\"\n";
context_set_arg = " args_[kNumArgs-1] = &context_;\n";
context_set_thread_pool = " context_.thread_pool = pool;\n";
context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n";
run_result = "!context_.error";
error_msg = "context_.error_msg";
}
// Create rewrite strings for namespace start and end.
string ns_start;
for (const string& n : opts.namespaces) {
ns_start += strings::StrCat("namespace ", n, " {\n");
}
ns_start += "\n";
string ns_end("\n");
for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
const string& n = opts.namespaces[i];
ns_end += strings::StrCat("} // end namespace ", n, "\n");
}
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph. An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off
#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
{{CONTEXT_INCLUDE}}
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { class ThreadPoolDevice; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
void* result, xla::ExecutableRunOptions* run_options,
void** args, void** temps);
{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. Usage example:
//
// {{CLASS}} computation;
// // ...set args using computation.argN methods
// CHECK(computation.Run());
// // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while
// it is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
// {{PROGRAM_SHAPE}}
//
// Memory stats:
// arg bytes total: {{ARG_BYTES_TOTAL}}
// arg bytes aligned: {{ARG_BYTES_ALIGNED}}
// temp bytes total: {{TEMP_BYTES_TOTAL}}
// temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
class {{CLASS}} {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() {
static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
return kArgSizes;
}
// AllocMode controls the buffer allocation mode.
enum class AllocMode {
// Allocate all buffers - args, results and temps.
ARGS_RESULTS_AND_TEMPS,
// Only allocate result and temp buffers.
// Use set_argN_data to set argument buffers before Run is called.
RESULTS_AND_TEMPS_ONLY,
};
{{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
}
{{CONTEXT_SET_ARG}}
alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
}
~{{CLASS}}() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
}
// Sets the thread pool to use during the Run call.
{{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
run_options_.set_intra_op_thread_pool(pool);
{{CONTEXT_SET_THREAD_POOL}}
return *this;
}
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
bool Run() {
{{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_);
return {{RUN_RESULT}};
}
// Returns the error message from the previous failed Run call.
tensorflow::string error_msg() const { return {{ERROR_MSG}}; }
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument, with the following
// general form:
//
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an effect. Must be
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
// to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
//
// T& argN(...dim indices...)
// Returns a reference to the value of type T for positional argument N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// void** args()
// Returns an array of argument buffers, where args()[N] is the buffer for
// positional argument N.
{{METHODS_ARG}}
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
// for each positional result, with the following general form:
//
// T* resultN_data()
// Returns the buffer of type T for positional result N.
//
// T& resultN(...dim indices...)
// Returns a reference to the value of type T for positional result N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// void** results()
// Returns an array of result buffers, where results()[N] is the buffer for
// positional result N.
//
// Unlike the arg methods, there is no set_resultN_data method. The result
// buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = {{TEMP_NUM}};
// The 0-based index of the result in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
return kTempSizes;
}
void* args_[kNumArgs];
void* temps_[kNumTemps];
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
xla::ExecutableRunOptions run_options_;
{{CONTEXT_MEMBER_VAR}}
TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}});
};
{{NS_END}}
#endif // TFCOMPILE_GENERATED_{{ENTRY}}_H_
// clang-format on
)";
// The replacement strategy is naive, but good enough for our purposes.
const std::vector<std::pair<string, string>> rewrites = {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
{"{{CLASS}}", opts.class_name},
{"{{CONTEXT_INCLUDE}}\n", context_include},
{"{{CONTEXT_MEMBER_VAR}}\n", context_member_var},
{"{{CONTEXT_SET_ARG}}\n", context_set_arg},
{"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool},
{"{{ENTRY}}", compile_result.entry_point},
{"{{ERROR_MSG}}", error_msg},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
{"{{RUN_RESULT}}", run_result},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
Status ParseCppClass(const string& cpp_class, string* class_name,
std::vector<string>* namespaces) {
class_name->clear();
namespaces->clear();
size_t begin = 0;
size_t end = 0;
while ((end = cpp_class.find("::", begin)) != string::npos) {
const string ns = cpp_class.substr(begin, end - begin);
TF_RETURN_IF_ERROR(ValidateCppIdent(
ns, "in namespace component of cpp_class: " + cpp_class));
namespaces->push_back(ns);
begin = end + 2; // +2 to skip the two colons
}
const string name = cpp_class.substr(begin);
TF_RETURN_IF_ERROR(
ValidateCppIdent(name, "in class name of cpp_class: " + cpp_class));
*class_name = name;
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow
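
The "poor-man's text templating" used by GenerateHeader boils down to repeated find-and-replace of {{TOKEN}} placeholders over the whole header string. Below is a standalone sketch of that approach (not part of this change; ReplaceAll here is a hand-rolled stand-in for str_util::ReplaceAllPairs):

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Naive find-and-replace of every occurrence of `from` with `to`; applying it
// once per rewrite pair is roughly O(N*M) in text size N and pair count M.
static void ReplaceAll(std::string* text, const std::string& from,
                       const std::string& to) {
  size_t pos = 0;
  while ((pos = text->find(from, pos)) != std::string::npos) {
    text->replace(pos, from.size(), to);
    pos += to.size();
  }
}

int main() {
  std::string header =
      "class {{CLASS}} { static constexpr size_t kNumArgs = {{ARG_NUM}}; };";
  const std::vector<std::pair<std::string, std::string>> rewrites = {
      {"{{CLASS}}", "MyClass"}, {"{{ARG_NUM}}", "3"}};
  for (const auto& rewrite : rewrites) {
    ReplaceAll(&header, rewrite.first, rewrite.second);
  }
  // Prints: class MyClass { static constexpr size_t kNumArgs = 3; };
  std::printf("%s\n", header.c_str());
  return 0;
}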

tensorflow/compiler/aot/codegen.h
@@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_
#define TENSORFLOW_COMPILER_AOT_CODEGEN_H_
#include <string>
#include <vector>
#include "tensorflow/compiler/aot/compile.h"
namespace tensorflow {
namespace tfcompile {
// HeaderOpts specifies options for header-file generation.
struct HeaderOpts {
// The name of the generated C++ class, wrapping the generated function.
string class_name;
// Namespaces specifies a list of C++ namespaces to add to the generated
// header. If empty, all symbols will be in the global namespace.
std::vector<string> namespaces;
};
// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
const CompileResult& compile_result, string* header);
// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
// components. The syntax is [[<optional_namespace>::],...]<class_name>. This
// mirrors the C++ syntax for referring to a class, where multiple namespaces
// may precede the class name, separated by double-colons.
Status ParseCppClass(const string& cpp_class, string* class_name,
std::vector<string>* namespaces);
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_CODEGEN_H_
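
As a usage note, ParseCppClass splits a qualified name such as "foo::bar::MyClass" into the namespace list {"foo", "bar"} and the class name "MyClass". A minimal caller (a sketch, not part of this change) might look like:

#include <cstdio>
#include <string>
#include <vector>

#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/core/lib/core/status.h"

int main() {
  tensorflow::string class_name;
  std::vector<tensorflow::string> namespaces;
  const tensorflow::Status status = tensorflow::tfcompile::ParseCppClass(
      "foo::bar::MyClass", &class_name, &namespaces);
  if (!status.ok()) return 1;
  // class_name is now "MyClass"; namespaces is {"foo", "bar"}.
  std::printf("class=%s, namespace depth=%zu\n", class_name.c_str(),
              namespaces.size());
  return 0;
}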

tensorflow/compiler/aot/codegen_test.cc
@@ -0,0 +1,137 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/codegen.h"
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
class ParseCppClassTest : public ::testing::Test {
protected:
void ExpectOK(const string& cpp_class, const string& want_class_name,
const std::vector<string>& want_namespaces) {
string class_name;
std::vector<string> namespaces;
TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
EXPECT_EQ(class_name, want_class_name);
EXPECT_EQ(namespaces, want_namespaces);
}
void ExpectFail(const string& cpp_class) {
string class_name;
std::vector<string> namespaces;
EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK());
}
};
TEST_F(ParseCppClassTest, ParseOK) {
ExpectOK("MyClass", "MyClass", {});
ExpectOK("_MyClass", "_MyClass", {});
ExpectOK("a::MyClass", "MyClass", {"a"});
ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
ExpectOK("foo::MyClass", "MyClass", {"foo"});
ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
// Make sure we didn't skip a valid letter or digit
string ident;
for (char c = 'a'; c <= 'z'; c++) {
ident.append(1, c);
}
for (char c = 'A'; c <= 'Z'; c++) {
ident.append(1, c);
}
for (char c = '0'; c <= '9'; c++) {
ident.append(1, c);
}
ident += "_";
ExpectOK(ident, ident, {});
ExpectOK(ident + "::" + ident, ident, {ident});
ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
}
TEST_F(ParseCppClassTest, ParseFail) {
ExpectFail("");
ExpectFail("::");
ExpectFail("::MyClass"); // valid C++, but disallowed for simpler code.
ExpectFail("0");
ExpectFail("a.b");
ExpectFail("a:b");
ExpectFail("good::.bad");
ExpectFail("good:::bad");
ExpectFail("good:: bad");
ExpectFail("good::0bad");
}
TEST(GenerateHeader, Golden) {
HeaderOpts opts;
opts.class_name = "MyClass";
opts.namespaces = {"foo", "bar"};
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("feed0");
feed->set_name("myfeed");
feed = config.add_feed();
feed->mutable_id()->set_node_name("feed1");
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
compile_result.aot.reset(
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
xla::ShapeUtil::MakeOpaqueShape(),
},
xla::ShapeUtil::MakeShape(xla::U32, {5, 6}));
compile_result.has_context_arg = true;
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;
string header;
TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header));
// Compare against the golden file.
const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(),
"compiler/aot/codegen_test_h.golden");
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
if (update_golden) {
TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header));
}
string golden_data;
TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data));
EXPECT_EQ(header, golden_data);
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

tensorflow/compiler/aot/codegen_test_h.golden
@@ -0,0 +1,268 @@
// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph. An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off
#ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { class ThreadPoolDevice; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void entry_point(
void* result, xla::ExecutableRunOptions* run_options,
void** args, void** temps);
namespace foo {
namespace bar {
// MyClass represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. Usage example:
//
// MyClass computation;
// // ...set args using computation.argN methods
// CHECK(computation.Run());
// // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while
// it is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6]
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 128
// temp bytes total: 126
// temp bytes aligned: 224
class MyClass {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 3;
// Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() {
static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1};
return kArgSizes;
}
// AllocMode controls the buffer allocation mode.
enum class AllocMode {
// Allocate all buffers - args, results and temps.
ARGS_RESULTS_AND_TEMPS,
// Only allocate result and temp buffers.
// Use set_argN_data to set argument buffers before Run is called.
RESULTS_AND_TEMPS_ONLY,
};
MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
}
args_[kNumArgs-1] = &context_;
alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
}
~MyClass() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
}
// Sets the thread pool to use during the Run call.
MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
run_options_.set_intra_op_thread_pool(pool);
context_.thread_pool = pool;
return *this;
}
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
bool Run() {
entry_point(temps_[kResultIndex], &run_options_, args_, temps_);
return !context_.error;
}
// Returns the error message from the previous failed Run call.
tensorflow::string error_msg() const { return context_.error_msg; }
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument, with the following
// general form:
//
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an effect. Must be
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
// to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
//
// T& argN(...dim indices...)
// Returns a reference to the value of type T for positional argument N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// void** args()
// Returns an array of argument buffers, where args()[N] is the buffer for
// positional argument N.
void** args() { return args_; }
const void *const *args() const { return args_; }
void set_arg0_data(void* data) {
args_[0] = data;
}
float* arg0_data() {
return static_cast<float*>(args_[0]);
}
float& arg0(size_t dim0, size_t dim1) {
return (*static_cast<float(*)[1][2]>(
args_[0]))[dim0][dim1];
}
const float* arg0_data() const {
return static_cast<const float*>(args_[0]);
}
const float& arg0(size_t dim0, size_t dim1) const {
return (*static_cast<const float(*)[1][2]>(
args_[0]))[dim0][dim1];
}
void set_arg_myfeed_data(void* data) {
args_[0] = data;
}
float* arg_myfeed_data() {
return static_cast<float*>(args_[0]);
}
float& arg_myfeed(size_t dim0, size_t dim1) {
return (*static_cast<float(*)[1][2]>(
args_[0]))[dim0][dim1];
}
const float* arg_myfeed_data() const {
return static_cast<const float*>(args_[0]);
}
const float& arg_myfeed(size_t dim0, size_t dim1) const {
return (*static_cast<const float(*)[1][2]>(
args_[0]))[dim0][dim1];
}
void set_arg1_data(void* data) {
args_[1] = data;
}
tensorflow::int64* arg1_data() {
return static_cast<tensorflow::int64*>(args_[1]);
}
tensorflow::int64& arg1(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::int64(*)[3][4]>(
args_[1]))[dim0][dim1];
}
const tensorflow::int64* arg1_data() const {
return static_cast<const tensorflow::int64*>(args_[1]);
}
const tensorflow::int64& arg1(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::int64(*)[3][4]>(
args_[1]))[dim0][dim1];
}
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
// for each positional result, with the following general form:
//
// T* resultN_data()
// Returns the buffer of type T for positional result N.
//
// T& resultN(...dim indices...)
// Returns a reference to the value of type T for positional result N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// void** results()
// Returns an array of result buffers, where results()[N] is the buffer for
// positional result N.
//
// Unlike the arg methods, there is no set_resultN_data method. The result
// buffers are managed internally, and may change after each call to Run.
void** results() { return temps_ + kResultIndex; }
const void *const *results() const { return temps_ + kResultIndex; }
tensorflow::uint32* result0_data() {
return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
}
tensorflow::uint32& result0(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::uint32(*)[5][6]>(
temps_[kResultIndex]))[dim0][dim1];
}
const tensorflow::uint32* result0_data() const {
return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
}
const tensorflow::uint32& result0(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
temps_[kResultIndex]))[dim0][dim1];
}
tensorflow::uint32* result_myfetch_data() {
return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
}
tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::uint32(*)[5][6]>(
temps_[kResultIndex]))[dim0][dim1];
}
const tensorflow::uint32* result_myfetch_data() const {
return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
}
const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
temps_[kResultIndex]))[dim0][dim1];
}
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = 6;
// The 0-based index of the result in the temporary buffers.
static constexpr size_t kResultIndex = 5;
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
return kTempSizes;
}
void* args_[kNumArgs];
void* temps_[kNumTemps];
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
xla::ExecutableRunOptions run_options_;
tensorflow::XlaLocalRuntimeContext context_;
TF_DISALLOW_COPY_AND_ASSIGN(MyClass);
};
} // end namespace bar
} // end namespace foo
#endif // TFCOMPILE_GENERATED_entry_point_H_
// clang-format on
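
To make the generated API concrete, a caller linked against the tfcompile-produced object file could drive MyClass roughly as follows. This is an editorial sketch, not part of this change; "my_class.h" is a placeholder for the real generated header, and the thread-pool setup mirrors benchmark_main.template above.

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <cstdio>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "my_class.h"  // placeholder for the tf_library-generated header

int main() {
  Eigen::ThreadPool pool(1 /* num_threads */);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  foo::bar::MyClass computation;
  computation.set_thread_pool(&device);

  // Fill the f32[1,2] feed named "myfeed" and the unnamed s64[3,4] feed.
  computation.arg_myfeed(0, 0) = 1.0f;
  computation.arg_myfeed(0, 1) = 2.0f;
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 4; ++j) {
      computation.arg1(i, j) = i * 4 + j;
    }
  }

  if (!computation.Run()) {
    std::printf("Run failed: %s\n", computation.error_msg().c_str());
    return 1;
  }

  // Read back the u32[5,6] fetch named "myfetch".
  std::printf("result_myfetch(0, 0) = %u\n", computation.result_myfetch(0, 0));
  return 0;
}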

tensorflow/compiler/aot/compile.cc
@@ -0,0 +1,416 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/compile.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace tfcompile {
const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";
namespace {
Status DumpGraph(const MainFlags& flags, const string& name,
const Graph& graph) {
if (flags.debug_dir.empty()) {
return Status::OK();
}
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
return WriteTextProto(Env::Default(), file, graph_def);
}
string TensorIdToString(const TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. For each feed node, replaces all matching edges so that
// they point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<Feed>& feeds) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const Feed& feed = feeds[arg_index];
const TensorId& id = feed.id();
auto it = node_map.find(id.node_name());
if (it == node_map.end()) {
return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
}
const Node* feed_node = it->second;
if (id.output_index() >= feed_node->num_outputs()) {
return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
", output index should be < ",
feed_node->num_outputs());
}
// TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
// if we can determine it. That way the graph will be initialized with
// whatever shapes we can infer, while the user can still explicitly specify
// or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(id.output_index())))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(id))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
// Collects out-edges from the feed node that have a matching edge index;
// these will be replaced with edges from the arg node instead. Also
// replaces all control edges from Placeholder feed nodes; similar code
// exists in subgraph::RewriteGraphForExecution.
// TODO(toddw): Why only replace control edges from Placeholder?
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
if (edge->src_output() == id.output_index() ||
(edge->src_output() == Graph::kControlSlot &&
feed_node->type_string() == "Placeholder")) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
if (edge->src_output() == id.output_index()) {
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
} else {
CHECK_EQ(edge->src_output(), Graph::kControlSlot);
graph->AddControlEdge(arg_node, edge->dst());
}
graph->RemoveEdge(edge);
}
}
return Status::OK();
}
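// Illustrative sketch (hypothetical node names, not taken from any particular
// graph): for a feed id "x_hold:0" on a Placeholder "x_hold" whose output 0
// feeds nodes "mul" and "add", the loop above creates a new "_arg_0" node of
// op _Arg and rewrites the edges
//   x_hold:0 -> mul,  x_hold:0 -> add
// to
//   _arg_0:0 -> mul,  _arg_0:0 -> add
// The original "x_hold" node is left in place; it becomes unreachable from the
// fetches and is removed later by the prune in RewriteAndPruneGraph.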
// Each fetch id identifies the positional output of some node. For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<Fetch>& fetches,
std::unordered_set<const Node*>* retval_nodes) {
for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
const TensorId& id = fetches[ret_index].id();
auto it = node_map.find(id.node_name());
if (it == node_map.end()) {
return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
}
Node* fetch_node = it->second;
if (id.output_index() >= fetch_node->num_outputs()) {
return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
", output index should be < ",
fetch_node->num_outputs());
}
// Connects fetch_node -> retval_node.
Node* retval_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
.Input(fetch_node, id.output_index())
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
.Attr("index", ret_index)
.Attr(kFetchIdAttr, TensorIdToString(id))
.Finalize(graph, &retval_node));
retval_nodes->insert(retval_node);
}
return Status::OK();
}
// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(Graph* graph, const Config& config,
const MainFlags& flags) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
PruneForReverseReachability(graph, retval_nodes);
FixupSourceAndSinkEdges(graph);
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
std::set<string> missing_feeds, missing_fetches;
for (const Feed& feed : config.feed()) {
missing_feeds.insert(TensorIdToString(feed.id()));
}
for (const Fetch& fetch : config.fetch()) {
missing_fetches.insert(TensorIdToString(fetch.id()));
}
for (const Node* n : graph->nodes()) {
if (n->type_string() == kArgOp) {
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp, " node found with unknown feed id: ",
feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ",
fetch_id);
}
}
}
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted("Post graph-pruning", ", missing feeds: ",
str_util::Join(missing_feeds, ", "),
", missing fetches: ",
str_util::Join(missing_fetches, ", "));
}
return Status::OK();
}
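// As a concrete example (based on the test_graph_tfadd graph and config that
// appear later in this change): with feeds x_const:0 and y_const:0 and fetch
// x_y_sum:0, the rewrite produces _arg_0 and _arg_1 feeding the Add node
// "x_y_sum", plus a _retval_0 consuming x_y_sum:0; the two Const nodes are no
// longer reachable from _retval_0 and are pruned away.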
// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
std::map<int, Node*> indexed_arg_nodes;
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
auto insert_result = indexed_arg_nodes.insert({index, n});
if (!insert_result.second) {
const Node* dup = insert_result.first->second;
return errors::InvalidArgument(
"Multiple ", kArgOp, " nodes with index ", index, ", ",
n->DebugString(), " and ", dup->DebugString());
}
}
}
arg_nodes->clear();
for (const auto& index_node : indexed_arg_nodes) {
if (index_node.first != arg_nodes->size()) {
return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
arg_nodes->size(), ", but got index ",
index_node.first);
}
arg_nodes->push_back(index_node.second);
}
return Status::OK();
}
// Fills in xla_args from the corresponding _Arg nodes in the graph.
Status CreateXlaArgs(const Graph& graph,
std::vector<XlaCompiler::Argument>* xla_args) {
std::vector<Node*> arg_nodes;
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
for (const Node* node : arg_nodes) {
XlaCompiler::Argument arg;
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
return Status::OK();
}
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
const FunctionLibraryDefinition* flib_def,
xla::Computation* computation, bool* has_context_arg) {
// Create a device and context to convert the graph into an XLA computation.
XlaOpRegistry::RegisterJitKernels();
// Populate the context with args from the graph.
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
}
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
compiler_options.allow_cpu_custom_calls = true;
XlaCompiler compiler(compiler_options);
std::unique_ptr<FunctionLibraryRuntime> flib_run(NewFunctionLibraryRuntime(
compiler.device_mgr(), Env::Default(), compiler.device(),
graph->versions().producer(), flib_def, OptimizerOptions()));
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph),
flib_run.get(), xla_args,
false /* use_tuple_arg */, &result));
*has_context_arg = result.requires_runtime_context;
*computation = std::move(result.computation);
int num_const_results = 0;
for (int i = 0; i < result.outputs.size(); ++i) {
// Ending up with const results (i.e. output args) is an error, since it
// means that one or more fetches that the user specified will be dropped
// from the generated function. It's most likely a configuration error,
// since the user shouldn't be asking for output args that end up as consts.
//
// TODO(toddw): Provide a way for the user to access const output args,
// e.g. perhaps hard-coded into the header, or somehow copied into the
// output buffers.
if (result.outputs[i].is_constant) {
++num_const_results;
LOG(ERROR) << "ConstRetVal index:" << i
<< " value:" << result.outputs[i].constant_value.DebugString();
}
}
if (num_const_results > 0) {
return errors::Unimplemented(
"Conversion from TensorFlow graph to XLA resulted in ",
num_const_results,
" constant results. The configuration of "
"the output args (i.e. fetch ids) is probably wrong.");
}
if (computation->IsNull()) {
return errors::Aborted(
"Conversion from TensorFlow graph to XLA resulted in an empty "
"computation.");
}
return Status::OK();
}
// Compiles the XLA computation into executable code.
Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
// TODO(toddw): Should we let the user choose the major/minor ordering?
xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
client->GetComputationShape(computation);
if (!pshape_or.ok()) {
return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message());
}
compile_result->program_shape = *pshape_or.ValueOrDie();
xla::ProgramShape* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> aot_or =
client->CompileAheadOfTime(computation, arg_layouts, pshape->result(),
aot_opts);
if (!aot_or.ok()) {
return errors::Unknown("XLA compilation failed: ",
aot_or.status().error_message());
}
compile_result->aot =
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
aot_or.ConsumeValueOrDie());
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->pointer_size =
xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
return Status::OK();
}
} // namespace
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, const FunctionLibraryDefinition* flib,
std::unique_ptr<Graph>* graph) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
std::unique_ptr<Graph> g(new Graph(flib));
GraphDef copy_def(graph_def);
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
0 /*node_offset*/));
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
*graph = std::move(g);
return Status::OK();
}
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
const FunctionLibraryDefinition* flib,
CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
namespace gpu = perftools::gputools;
gpu::Platform* cpu_platform =
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
xla::LocalClient* client =
xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
xla::Computation computation;
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib,
&computation,
&compile_result->has_context_arg));
if (!flags.debug_dir.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
computation.Snapshot());
string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb");
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module));
}
xla::cpu::CpuAotCompilationOptions aot_opts(
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
return CompileXla(client, computation, aot_opts, compile_result);
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,92 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_COMPILE_H_
#define TENSORFLOW_COMPILER_AOT_COMPILE_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace tfcompile {
// Constants for op types and attribute names.
extern const char* const kArgOp;
extern const char* const kRetvalOp;
extern const char* const kFeedIdAttr;
extern const char* const kFetchIdAttr;
extern const char* const kShapeAttr;
extern const char* const kDebugNameAttr;
// InitGraph creates a graph based on the graph_def, that may then be compiled
// by CompileGraph.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node. All _Retval nodes also point to a special CompileExpressions node,
// used internally to finish the compilation.
//
// The rewritten graph is then pruned to only contain the portion necessary to
// compute the outputs. If flags.debug_dir is non-empty, the rewritten and
// pruned graphs are dumped to that directory for debugging.
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, const FunctionLibraryDefinition* flib,
std::unique_ptr<Graph>* graph);
// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
// Contains object file and meta-info.
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
xla::ProgramShape program_shape; // Static shape of args and results.
bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext?
string entry_point; // Name of generated function.
int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph into an object file containing a function
// that performs the graph operations.
//
// The graph must have _Arg and _Retval nodes representing the function inputs
// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
// value=TensorShape) representing the static shape of that input, and every
// _Retval node must point to a CompileExpressions node.
//
// Typically InitGraph is called to perform this initialization, followed by
// full specification of the shape attributes.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
const FunctionLibraryDefinition* flib,
CompileResult* result);
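// Minimal usage sketch (assumes the caller has already parsed a GraphDef,
// Config and MainFlags; error handling elided, and the flib construction shown
// here is only one possible way to build it):
//
//   FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
//   std::unique_ptr<Graph> graph;
//   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
//   CompileResult compile_result;
//   TF_RETURN_IF_ERROR(
//       CompileGraph(std::move(graph), flags, &flib, &compile_result));
//   // compile_result.aot now holds the generated object file and buffer info.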
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_COMPILE_H_

View File

@ -0,0 +1,72 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/flags.h"
namespace tensorflow {
namespace tfcompile {
void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
const std::vector<Flag> tmp = {
{"graph", &flags->graph,
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"config", &flags->config,
"Input file containing Config proto. If the file ends in '.pbtxt' it "
"is expected to be in the human-readable proto text format, otherwise "
"it is expected to be in the proto binary format."},
{"dump_fetch_nodes", &flags->dump_fetch_nodes,
"If set, only flags related to fetches are processed, and the resulting "
"fetch nodes will be dumped to stdout in a comma-separated list. "
"Typically used to format arguments for other tools, e.g. "
"freeze_graph."},
{"debug_dir", &flags->debug_dir,
"Specifies a directory to dump debugging information, including "
"rewritten graphs and the XLA HLO module."},
// Flags controlling the XLA ahead-of-time compilation, that correspond to
// the fields of xla::cpu::CpuAotCompilationOptions.
//
// TODO(toddw): The following flags also need to be supported:
// --xla_cpu_llvm_opt_level
// --xla_cpu_llvm_cl_opts
{"target_triple", &flags->target_triple,
"Target platform, similar to the clang -target flag. The general "
"format is <arch><sub>-<vendor>-<sys>-<abi>. "
"http://clang.llvm.org/docs/CrossCompilation.html#target-triple."},
{"target_cpu", &flags->target_cpu,
"Target cpu, similar to the clang -mcpu flag. "
"http://clang.llvm.org/docs/CrossCompilation.html#cpu-fpu-abi"},
{"target_features", &flags->target_features,
"Target features, e.g. +avx2, +neon, etc."},
{"entry_point", &flags->entry_point,
"Name of the generated function. If multiple generated object files "
"will be linked into the same binary, each will need a unique entry "
"point."},
{"cpp_class", &flags->cpp_class,
"Name of the generated C++ class, wrapping the generated function. The "
"syntax of this flag is [[<optional_namespace>::],...]<class_name>. "
"This mirrors the C++ syntax for referring to a class, where multiple "
"namespaces may precede the class name, separated by double-colons. "
"The class will be generated in the given namespace(s), or if no "
"namespaces are given, within the global namespace."},
{"out_object", &flags->out_object, "Output object file name."},
{"out_header", &flags->out_header, "Output header file name."},
};
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_FLAGS_H_
#define TENSORFLOW_COMPILER_AOT_FLAGS_H_
#include <string>
#include <vector>
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags {
string graph;
string config;
bool dump_fetch_nodes = false;
string debug_dir;
string target_triple;
string target_cpu;
string target_features;
string entry_point;
string cpp_class;
string out_object;
string out_header;
};
// Appends to flag_list a tensorflow::Flag for each field in MainFlags.
void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags);
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_FLAGS_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/runtime.h"
#include <stdlib.h>
#include "tensorflow/core/platform/dynamic_annotations.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
namespace {
// Inline memory allocation routines here, because depending on '//base' brings
// in libraries which use c++ streams, which adds considerable code size on
// android.
inline void* aligned_malloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
return memalign(minimum_alignment, size);
#else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN
void* ptr = nullptr;
// posix_memalign requires that the requested alignment be at least
// sizeof(void*). In this case, fall back on malloc which should return memory
// aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return malloc(size);
if (posix_memalign(&ptr, minimum_alignment, size) != 0)
return nullptr;
else
return ptr;
#endif
}
inline void aligned_free(void* aligned_memory) { free(aligned_memory); }
size_t align_to(size_t n, size_t align) {
return (((n - 1) / align) + 1) * align;
}
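// For example, with kAlign == 32: align_to(3, 32) == 32, align_to(32, 32) ==
// 32, and align_to(33, 32) == 64, matching the expectations in runtime_test.cc.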
} // namespace
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
if (sizes[i] != -1) {
total += align_to(sizes[i], kAlign);
}
}
return total;
}
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized) {
const size_t total = aligned_buffer_bytes(sizes, n);
void* contiguous = nullptr;
if (total > 0) {
contiguous = aligned_malloc(total, kAlign);
if (annotate_initialized) {
// Since the memory for temp buffers is written to by JITed code, msan has
// no way of knowing the memory was initialized, so explicitly mark it.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(contiguous, total);
}
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
if (sizes[i] == -1) {
bufs[i] = nullptr;
} else {
bufs[i] = reinterpret_cast<void*>(pos);
pos += align_to(sizes[i], kAlign);
}
}
return contiguous;
}
void FreeContiguous(void* contiguous) {
if (contiguous != nullptr) {
aligned_free(contiguous);
}
}
} // namespace runtime
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,58 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains utilities to make it easier to invoke functions generated
// by tfcompile. Usage of these utilities is optional.
#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
static constexpr size_t kAlign = 32;
// aligned_buffer_bytes returns the total number of bytes needed to hold the
// `n` buffers described by `sizes`, with each size rounded up to a kAlign
// byte boundary; entries equal to -1 are skipped.
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
// MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
// `sizes`. If `annotate_initialized` is set, the allocated memory will be
// annotated as having been initialized - this is useful when allocating
// temporary buffers.
//
// A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns the
// head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized);
// FreeContiguous frees the contiguous block of memory allocated by
// MallocContiguousBuffers.
void FreeContiguous(void* contiguous);
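// Usage sketch (hypothetical sizes; -1 follows the convention described above
// for a buffer whose pointer should be left null):
//
//   static const intptr_t sizes[3] = {16, -1, 64};
//   void* bufs[3];
//   void* block =
//       MallocContiguousBuffers(sizes, 3, bufs, /*annotate_initialized=*/false);
//   // bufs[0] and bufs[2] now point into `block`, each kAlign-aligned;
//   // bufs[1] is nullptr. When the buffers are no longer needed:
//   FreeContiguous(block);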
} // namespace runtime
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_

View File

@ -0,0 +1,125 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
namespace {
TEST(Runtime, AlignmentValue) {
// We've chosen 32 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
// The tfcompile runtime also has a requirement that comes from the xla
// generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8
// So any value that we choose must abide by that constraint as well.
EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment);
}
TEST(Runtime, AlignedBufferBytes) {
EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1};
EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32);
static constexpr intptr_t sizesC[1] = {32};
EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192);
}
void* add_ptr(void* base, uintptr_t delta) {
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(base) + delta);
}
// To test MallocContiguousBuffers and FreeContiguous, we just check for
// expected nullptrs, and write to each byte of allocated memory. We rely on
// the leak checker to tell us if there's an inconsistency between malloc and
// free. We also check the contiguous property.
TEST(Runtime, MallocFreeContiguousBuffers) {
// Test empty sizes.
void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr);
FreeContiguous(base);
// Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1];
base = MallocContiguousBuffers(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr);
FreeContiguous(base);
// Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3};
void* bufB[1];
base = MallocContiguousBuffers(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]);
bufB0_bytes[0] = 'A';
bufB0_bytes[1] = 'B';
bufB0_bytes[2] = 'C';
FreeContiguous(base);
// Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3};
void* bufC[1];
base = MallocContiguousBuffers(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]);
bufC0_bytes[0] = 'A';
bufC0_bytes[1] = 'B';
bufC0_bytes[2] = 'C';
FreeContiguous(base);
// Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7];
base = MallocContiguousBuffers(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
EXPECT_EQ(bufD[2], add_ptr(base, 32));
EXPECT_EQ(bufD[3], nullptr);
EXPECT_EQ(bufD[4], add_ptr(base, 64));
EXPECT_EQ(bufD[5], add_ptr(base, 128));
EXPECT_EQ(bufD[6], add_ptr(base, 160));
for (int i = 0; i < 7; ++i) {
const intptr_t size = sizesD[i];
if (size != -1) {
char* bufD_bytes = static_cast<char*>(bufD[i]);
for (size_t j = 0; j < size; ++j) {
bufD_bytes[j] = 'A' + j;
}
}
}
FreeContiguous(base);
}
} // namespace
} // namespace runtime
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,94 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Generated by the tf_library build rule. DO NOT EDIT!
//
// This file contains a test and benchmark for the function generated by
// tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to
// real values before this file can be compiled.
//
// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
// TFCOMPILE_NAME : Name for tests and benchmarks.
//
// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
// generates a cc_test rule for you.
// These macros must be defined before eigen files are included.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
// clang-format off
#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
// clang-format on
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
// Macros that expand to tokens based on the entry point name.
// clang-format off
#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
#define TEST_NAME {{TFCOMPILE_NAME}}Test // NOLINT(whitespace/braces)
#define BM_NAME BM_{{TFCOMPILE_NAME}} // NOLINT(whitespace/braces)
// clang-format on
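// For example, for the tf_library target "test_graph_tfadd" defined in
// tensorflow/compiler/aot/tests/BUILD (a sketch; the exact rewritten values
// are produced by the tf_library macro in tfcompile.bzl):
//   {{TFCOMPILE_HEADER}}    -> tensorflow/compiler/aot/tests/test_graph_tfadd.h
//   {{TFCOMPILE_CPP_CLASS}} -> AddComp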
namespace tensorflow {
namespace tfcompile {
namespace {
void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
for (int i = 0; i < n; ++i) {
if (sizes[i] != -1) {
memset(bufs[i], 0, sizes[i]);
}
}
}
// Trivial test that runs the generated function to ensure it doesn't crash.
TEST(TEST_NAME, NoCrash) {
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
CPP_CLASS computation;
computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
EXPECT_TRUE(computation.Run());
}
// Simple benchmark that repeatedly runs the generated function.
void BM_NAME(int iters) {
testing::StopTiming();
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
CPP_CLASS computation;
computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
testing::StartTiming();
while (--iters) {
computation.Run();
}
testing::StopTiming();
}
BENCHMARK(BM_NAME);
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,63 @@
node {
name : "x_const"
op : "Const"
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 1
}
}
int_val: 1
}
}
}
attr {
key : "dtype"
value {
type : DT_INT32
}
}
}
node {
name : "y_const"
op : "Const"
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 1
}
}
int_val: 2
}
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name : "x_y_sum"
op : "Add"
input : "x_const"
input : "y_const"
attr {
key : "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 15
}

View File

@ -0,0 +1,146 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//visibility:private"],
)
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
test_suite(
name = "all_tests",
tags = ["manual"],
tests = [
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":tfcompile_test",
],
)
py_binary(
name = "make_test_graphs",
testonly = 1,
srcs = ["make_test_graphs.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
genrule(
name = "gen_test_graphs",
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
],
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
tools = [":make_test_graphs"],
)
tf_library(
name = "test_graph_tfadd",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfadd_with_ckpt",
testonly = 1,
config = "test_graph_tfadd_with_ckpt.config.pbtxt",
cpp_class = "AddWithCkptComp",
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
graph = "test_graph_tfadd_with_ckpt.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfadd_with_ckpt_saver",
testonly = 1,
config = "test_graph_tfadd_with_ckpt.config.pbtxt",
cpp_class = "AddWithCkptSaverComp",
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
graph = "test_graph_tfadd_with_ckpt_saver.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfgather",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfmatmul",
testonly = 1,
config = "test_graph_tfmatmul.config.pbtxt",
cpp_class = "foo::bar::MatMulComp",
graph = "test_graph_tfmatmul.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfmatmulandadd",
testonly = 1,
config = "test_graph_tfmatmulandadd.config.pbtxt",
cpp_class = "MatMulAndAddComp",
graph = "test_graph_tfmatmulandadd.pb",
tags = ["manual"],
)
cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
tags = ["manual"],
deps = [
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,119 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate tensorflow graphs for testing tfcompile."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.training import saver as saver_lib
flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('out_dir', '',
'Output directory for graphs, checkpoints and savers.')
def tfadd():
x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
math_ops.add(x, y, name='x_y_sum')
def tfadd_with_ckpt():
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = variables.Variable(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
with session.Session() as sess:
sess.run(init_op)
sess.run(y.assign(y + 42))
# Without the checkpoint, the variable won't be set to 42.
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
saver.save(sess, ckpt)
def tfadd_with_ckpt_saver():
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = variables.Variable(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1)
with session.Session() as sess:
sess.run(init_op)
sess.run(y.assign(y + 42))
# Without the checkpoint, the variable won't be set to 42.
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
saver.save(sess, ckpt_file)
# Without the SaverDef, the restore op won't be named correctly.
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
with open(saver_file, 'wb') as f:
f.write(saver.as_saver_def().SerializeToString())
def tfgather():
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
array_ops.gather(params, indices, name='gather_output')
def tfmatmul():
x = array_ops.placeholder(dtypes.float32, name='x_hold')
y = array_ops.placeholder(dtypes.float32, name='y_hold')
math_ops.matmul(x, y, name='x_y_prod')
def tfmatmulandadd():
# This tests multiple outputs.
x = array_ops.placeholder(dtypes.float32, name='x_hold')
y = array_ops.placeholder(dtypes.float32, name='y_hold')
math_ops.matmul(x, y, name='x_y_prod')
math_ops.add(x, y, name='x_y_sum')
def write_graph(build_graph):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
build_graph()
filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
with open(filename, 'wb') as f:
f.write(g.as_graph_def().SerializeToString())
def main(_):
write_graph(tfadd)
write_graph(tfadd_with_ckpt)
write_graph(tfadd_with_ckpt_saver)
write_graph(tfgather)
write_graph(tfmatmul)
write_graph(tfmatmulandadd)
if __name__ == '__main__':
app.run()

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,10 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_hold" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "params" }
shape {
dim { size: 4 }
}
}
feed {
id { node_name: "indices" }
shape {
dim { size: 2 }
}
}
fetch {
id { node_name: "gather_output" }
}

View File

@ -0,0 +1,18 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_hold" }
shape {
dim { size: 2 }
dim { size: 3 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 3 }
dim { size: 2 }
}
}
fetch {
id { node_name: "x_y_prod" }
}

View File

@ -0,0 +1,25 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_hold" }
shape {
dim { size: 2 }
dim { size: 2 }
}
name: "x"
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 2 }
dim { size: 2 }
}
name: "y"
}
fetch {
id { node_name: "x_y_prod" }
name: "x_y_prod"
}
fetch {
id { node_name: "x_y_sum" }
name: "x_y_sum"
}

View File

@ -0,0 +1,381 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
TEST(TFCompileTest, Add) {
AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg1_data(), add.args()[1]);
add.arg0() = 1;
add.arg1() = 2;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 3);
EXPECT_EQ(add.result0_data()[0], 3);
EXPECT_EQ(add.result0_data(), add.results()[0]);
add.arg0_data()[0] = 123;
add.arg1_data()[0] = 456;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 579);
EXPECT_EQ(add.result0_data()[0], 579);
EXPECT_EQ(add.result0_data(), add.results()[0]);
const AddComp& add_const = add;
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 123);
EXPECT_EQ(add_const.arg0_data()[0], 123);
EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
EXPECT_EQ(add_const.arg1(), 456);
EXPECT_EQ(add_const.arg1_data()[0], 456);
EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
EXPECT_EQ(add_const.result0(), 579);
EXPECT_EQ(add_const.result0_data()[0], 579);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
// Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers.
TEST(TFCompileTest, Add_SetArg) {
AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
int32 arg_x = 10;
int32 arg_y = 32;
add.set_arg0_data(&arg_x);
add.set_arg1_data(&arg_y);
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg1_data(), add.args()[1]);
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 42);
EXPECT_EQ(add.result0_data()[0], 42);
EXPECT_EQ(add.result0_data(), add.results()[0]);
}
TEST(TFCompileTest, AddWithCkpt) {
AddWithCkptComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
add.arg0() = 1;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 43);
EXPECT_EQ(add.result0_data()[0], 43);
EXPECT_EQ(add.result0_data(), add.results()[0]);
add.arg0_data()[0] = 111;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 153);
EXPECT_EQ(add.result0_data()[0], 153);
EXPECT_EQ(add.result0_data(), add.results()[0]);
const AddWithCkptComp& add_const = add;
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
TEST(TFCompileTest, AddWithCkptSaver) {
AddWithCkptSaverComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
add.arg0() = 1;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 43);
EXPECT_EQ(add.result0_data()[0], 43);
EXPECT_EQ(add.result0_data(), add.results()[0]);
add.arg0_data()[0] = 111;
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
EXPECT_EQ(add.result0(), 153);
EXPECT_EQ(add.result0_data()[0], 153);
EXPECT_EQ(add.result0_data(), add.results()[0]);
const AddWithCkptSaverComp& add_const = add;
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
// Successful gather.
{
const float params[4] = {1, 2, 3, 4};
std::copy(params + 0, params + 4, gather.arg0_data());
const int32 indices[2] = {1, 3};
std::copy(indices + 0, indices + 2, gather.arg1_data());
EXPECT_TRUE(gather.Run());
EXPECT_EQ(gather.error_msg(), "");
const float results[2] = {2, 4};
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather.result0(i), results[i]);
EXPECT_EQ(gather.result0_data()[i], results[i]);
}
EXPECT_EQ(gather.result0_data(), gather.results()[0]);
const GatherComp& gather_const = gather;
EXPECT_EQ(gather_const.error_msg(), "");
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(gather_const.arg0(i), params[i]);
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
}
EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.arg1(i), indices[i]);
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
}
EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.result0(i), results[i]);
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
}
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
}
// Bad indices return an error.
{
const float params[4] = {1, 2, 3, 4};
std::copy(params + 0, params + 4, gather.arg0_data());
const int32 indices[2] = {1, 4};
std::copy(indices + 0, indices + 2, gather.arg1_data());
EXPECT_FALSE(gather.Run());
EXPECT_EQ(gather.error_msg(), "Invalid index for gather");
}
}
TEST(TFCompileTest, MatMul2) {
Eigen::ThreadPool tp(2);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
// Test using the argN() methods.
{
matmul.arg0(0, 0) = 1;
matmul.arg0(0, 1) = 2;
matmul.arg0(0, 2) = 3;
matmul.arg0(1, 0) = 4;
matmul.arg0(1, 1) = 5;
matmul.arg0(1, 2) = 6;
matmul.arg1(0, 0) = 7;
matmul.arg1(0, 1) = 8;
matmul.arg1(1, 0) = 9;
matmul.arg1(1, 1) = 10;
matmul.arg1(2, 0) = 11;
matmul.arg1(2, 1) = 12;
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
const float results[4] = {58, 64, 139, 154};
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul.result0_data()[i], results[i]);
}
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
}
// Test using the argN_data() methods.
{
const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
std::copy(args + 0, args + 6, matmul.arg0_data());
std::copy(args + 6, args + 12, matmul.arg1_data());
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
const float results[4] = {5800, 6400, 13900, 15400};
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul.result0_data()[i], results[i]);
}
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
const foo::bar::MatMulComp& matmul_const = matmul;
EXPECT_EQ(matmul_const.error_msg(), "");
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
}
EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
}
EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
}
EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]);
}
}
// Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers.
TEST(TFCompileTest, MatMul2_SetArg) {
Eigen::ThreadPool tp(2);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul(
foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
matmul.set_thread_pool(&device);
// Test using the set_argN_data() methods.
float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}};
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
matmul.set_arg0_data(&arg0);
matmul.set_arg1_data(&arg1);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
const float results[4] = {58, 64, 139, 154};
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul.result0_data()[i], results[i]);
}
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
}
TEST(TFCompileTest, MatMulAndAdd1) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
MatMulAndAddComp muladd;
muladd.set_thread_pool(&device);
EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
// Test methods with positional args and results.
{
const float args[8] = {1, 2, 3, 4, 5, 6, 7, 8};
std::copy(args + 0, args + 4, muladd.arg0_data());
std::copy(args + 4, args + 8, muladd.arg1_data());
EXPECT_TRUE(muladd.Run());
EXPECT_EQ(muladd.error_msg(), "");
const float results0[4] = {19, 22, 43, 50};
const float results1[4] = {6, 8, 10, 12};
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd.result0_data()[i], results0[i]);
EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]);
EXPECT_EQ(muladd.result1_data()[i], results1[i]);
}
EXPECT_EQ(muladd.result0_data(), muladd.results()[0]);
EXPECT_EQ(muladd.result1_data(), muladd.results()[1]);
const MatMulAndAddComp& muladd_const = muladd;
EXPECT_EQ(muladd_const.error_msg(), "");
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
}
EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
}
EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]);
EXPECT_EQ(muladd_const.result1_data()[i], results1[i]);
}
EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]);
EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]);
}
// Test methods with named args and results.
{
const float args[8] = {10, 20, 30, 40, 50, 60, 70, 80};
std::copy(args + 0, args + 4, muladd.arg_x_data());
std::copy(args + 4, args + 8, muladd.arg_y_data());
EXPECT_TRUE(muladd.Run());
EXPECT_EQ(muladd.error_msg(), "");
const float results0[4] = {1900, 2200, 4300, 5000};
const float results1[4] = {60, 80, 100, 120};
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd.result_x_y_prod_data()[i], results0[i]);
EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]);
EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]);
}
EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]);
EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]);
// Test const methods.
const MatMulAndAddComp& muladd_const = muladd;
EXPECT_EQ(muladd_const.error_msg(), "");
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
}
EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
}
EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]);
EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]);
}
EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]);
EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]);
}
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,285 @@
# -*- Python -*-
"""Build macro that compiles a TensorFlow graph into a cc_library.
To use from your BUILD file, add the following line to load the macro:
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
Then call the macro like this:
tf_library(
name = "test_graph_tfmatmul",
config = "test_graph_tfmatmul.config.pbtxt",
cpp_class = "MatMulComp",
graph = ":test_graph_tfmatmul.pb",
)
"""
load("//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
def tf_library(name, graph, config,
freeze_checkpoint=None, freeze_saver=None,
cpp_class=None, gen_test=True, gen_benchmark=True,
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
Args:
name: The name of the build rule.
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
is expected to be in the human-readable proto text format, otherwise it is
expected to be in the proto binary format.
config: File containing tensorflow.tfcompile.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text format,
otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
convert variables into constants.
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
binary form, to convert variables into constants.
cpp_class: The name of the generated C++ class, wrapping the generated
function. The syntax of this flag is
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
for referring to a class, where multiple namespaces may precede the class
name, separated by double-colons. The class will be generated in the
given namespace(s), or if no namespaces are given, within the global
namespace.
gen_test: If True, also generate a cc_test rule that builds a simple
test and benchmark.
gen_benchmark: If True, also generate a binary with a simple benchmark.
Unlike the output of gen_test, this benchmark can be run on android.
visibility: Bazel build visibility.
testonly: Bazel testonly attribute.
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
tfcompile_tool: The tfcompile binary. A non-default can be passed to
use a tfcompile built with extra dependencies.
deps: a list of extra deps to include on the build rules for
the generated library.
tags: tags to apply to subsidiary build rules.
The output header is called <name>.h.
"""
if not cpp_class:
fail("cpp_class must be specified")
tfcompile_graph = graph
if freeze_checkpoint or freeze_saver:
if not freeze_checkpoint:
fail("freeze_checkpoint must be specified when freeze_saver is specified")
freeze_name = "freeze_" + name
freeze_file = freeze_name + ".pb"
# First run tfcompile to generate the list of out_nodes.
out_nodes_file = "out_nodes_" + freeze_name
native.genrule(
name=("gen_" + out_nodes_file),
srcs=[config],
outs=[out_nodes_file],
cmd=("$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools=[tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local=1,
tags=tags,
)
# Now run freeze_graph to convert variables into constants.
freeze_args = (" --input_graph=$(location " + graph + ")" +
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
" --output_graph=$(location " + freeze_file + ")" +
" --output_node_names=$$(<$(location " + out_nodes_file +
"))")
freeze_saver_srcs = []
if freeze_saver:
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
freeze_saver_srcs += [freeze_saver]
native.genrule(
name=freeze_name,
srcs=[
graph,
freeze_checkpoint,
out_nodes_file,
] + freeze_saver_srcs,
outs=[freeze_file],
cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
freeze_args),
tools=["//tensorflow/python/tools:freeze_graph"],
tags=tags,
)
tfcompile_graph = freeze_file
# Rule that runs tfcompile to produce the header and object file.
header_file = name + ".h"
object_file = name + ".o"
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
native.genrule(
name=("gen_" + name),
srcs=[
tfcompile_graph,
config,
],
outs=[
header_file,
object_file,
],
cmd=("$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_header=$(@D)/" + header_file +
" --out_object=$(@D)/" + object_file +
" " + (tfcompile_flags or "")),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
# Run tfcompile on the build host since it's typically faster on the local
# machine.
#
# Note that setting the local=1 attribute on a *test target* causes the
# test infrastructure to skip that test. However this is a genrule, not a
# test target, and runs with --genrule_strategy=forced_forge, meaning the
# local=1 attribute is ignored, and the genrule is still run.
#
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
local=1,
tags=tags,
)
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
native.cc_library(
name=name,
srcs=[object_file],
hdrs=[header_file],
visibility=visibility,
testonly=testonly,
deps = [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32",
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
"//tensorflow/core:framework_lite",
] + (deps or []),
tags=tags,
)
# Variables used for gen_test and gen_benchmark.
no_ns_name = ""
  # The class name is everything after the last "::".
  cpp_class_split = cpp_class.rsplit("::", maxsplit=1)
if len(cpp_class_split) == 1:
no_ns_name = cpp_class_split[0]
else:
no_ns_name = cpp_class_split[1]
sed_replace = (
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
if gen_test:
test_name = name + "_test"
test_file = test_name + ".cc"
# Rule to rewrite test.cc to produce the test_file.
native.genrule(
name=("gen_" + test_name),
testonly=1,
srcs=[
"//tensorflow/compiler/aot:test.cc",
header_file,
],
outs=[test_file],
cmd=("sed " + sed_replace +
" $(location //tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"),
tags=tags,
)
# The cc_test rule for the generated code.
native.cc_test(
name=test_name,
srcs=[test_file],
deps=[
":" + name,
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/aot:tf_library_test_main",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
tags=tags,
)
if gen_benchmark:
benchmark_name = name + "_benchmark"
benchmark_file = benchmark_name + ".cc"
benchmark_main = ("//tensorflow/compiler/aot:" +
"benchmark_main.template")
# Rule to rewrite benchmark.cc to produce the benchmark_file.
native.genrule(
name=("gen_" + benchmark_name),
srcs=[
benchmark_main,
header_file,
],
testonly = testonly,
outs=[benchmark_file],
cmd=("sed " + sed_replace +
" $(location " + benchmark_main + ") " +
"> $(OUTS)"),
tags=tags,
)
# The cc_benchmark rule for the generated code.
#
# Note: to get smaller size on android for comparison, compile with:
# --copt=-fvisibility=hidden
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
native.cc_binary(
name=benchmark_name,
srcs=[benchmark_file],
testonly = testonly,
copts = tf_copts(),
linkopts = if_android(["-pie", "-s"]),
deps=[
":" + name,
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/aot:benchmark",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
] + if_android([
"//tensorflow/compiler/aot:benchmark_extra_android",
]),
tags=tags,
)
def target_llvm_triple():
"""Returns the target LLVM triple to be used for compiling the target."""
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//conditions:default": "x86_64-pc-linux",
})
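# Example (hypothetical file and target names): a graph that still contains
# variables can be frozen and compiled in one step by also passing a
# checkpoint, e.g.
#
#   tf_library(
#       name = "my_frozen_model",
#       config = "my_frozen_model.config.pbtxt",
#       cpp_class = "mynamespace::MyFrozenModelComp",
#       freeze_checkpoint = "my_frozen_model.ckpt",
#       graph = "my_frozen_model.pb",
#   )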

View File

@ -0,0 +1,43 @@
syntax = "proto3";
package tensorflow.tfcompile;
option cc_enable_arenas = true;
option java_outer_classname = "CompileProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
import "tensorflow/core/framework/tensor_shape.proto";
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
// index of a particular node in the graph. If the output of the named node
// feeds into other node(s), this corresponds to one or more edges. Otherwise
// it doesn't correspond to any existing edges at all, e.g. for output nodes.
message TensorId {
string node_name = 1;
int64 output_index = 2;
};
// Feed represents a single feed tensor in the graph, which corresponds to an
// input argument for the generated function.
message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
// output argument for the generated function.
message Fetch {
TensorId id = 1;
string name = 2; // Optional name for generated code.
};
// Config represents configuration information for tfcompile.
message Config {
// Each feed is a positional input argument for the generated function. The
// order of each entry matches the order of each input argument.
repeated Feed feed = 1;
// Each fetch is a positional output argument for the generated function. The
// order of each entry matches the order of each output argument.
repeated Fetch fetch = 2;
};
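// A minimal Config in proto text format might look like this (the node names
// and shape below are illustrative placeholders, not taken from a real model):
//
//   feed {
//     id { node_name: "x_hold" }
//     shape { dim { size: 2 } dim { size: 2 } }
//     name: "x"
//   }
//   fetch {
//     id { node_name: "x_y_prod" }
//     name: "x_y_prod"
//   }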

View File

@ -0,0 +1,142 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace tfcompile {
const char kUsageHeader[] =
"tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
"resulting in an object file compiled for your target architecture, and a\n"
"header file that gives access to the functionality in the object file.\n"
"A typical invocation looks like this:\n"
"\n"
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
"\n";
Status ReadProtoFile(const string& kind, const string& fname,
protobuf::Message* proto) {
if (StringPiece(fname).ends_with(".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
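// ParseTensorId (below) follows TensorFlow's usual "node_name:output_index"
// naming convention: e.g. "foo:1" becomes node_name="foo", output_index=1,
// and a bare node name such as "foo" becomes output_index=0.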
void ParseTensorId(const string& name, TensorId* id) {
const std::pair<StringPiece, int> name_index = ParseTensorName(name);
id->set_node_name(name_index.first.ToString());
id->set_output_index(name_index.second);
}
Status Main(const MainFlags& flags) {
// Process config.
Config config;
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << str_util::Join(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
std::unique_ptr<Graph> graph;
FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph), flags, &flib, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
StringPiece(obj.data(), obj.size())));
HeaderOpts header_opts;
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
&header_opts.namespaces));
string header;
TF_RETURN_IF_ERROR(
GenerateHeader(header_opts, config, compile_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile
} // end namespace tensorflow
int main(int argc, char** argv) {
tensorflow::tfcompile::MainFlags flags;
flags.target_triple = "x86_64-pc-linux";
flags.out_object = "out.o";
flags.out_header = "out.h";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
QCHECK(parsed_flags_ok) << "\n" << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1 && !flags.config.empty() &&
(flags.dump_fetch_nodes ||
(!flags.graph.empty() && !flags.entry_point.empty())))
<< "\n"
<< usage;
TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
return 0;
}

View File

@ -0,0 +1,119 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include <set>
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
namespace {
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
Status ValidateTensorId(const TensorId& id) {
if (id.node_name().empty()) {
return errors::InvalidArgument("TensorId node_name must be non-empty");
}
if (id.output_index() < 0) {
return errors::InvalidArgument("TensorId output_index must be positive");
}
return Status::OK();
}
Status ValidateFeedFetchName(const string& kind, const string& name,
std::set<string>* names) {
if (!name.empty()) {
TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
if (!names->insert(name).second) {
return errors::InvalidArgument("duplicate ", kind, " name: ", name);
}
}
return Status::OK();
}
Status CheckFeedFetchNameConflicts(const string& kind,
const std::set<string>& names) {
// We don't allow the feeds or fetches to contain both "foo" and "foo_data",
// since that will cause a collision in codegen symbols.
for (const string& name : names) {
const string name_data(name + "_data");
if (names.find(name_data) != names.end()) {
return errors::InvalidArgument("conflicting ", kind, " name: ", name,
" and ", name_data);
}
}
return Status::OK();
}
} // namespace
Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
if (ident.empty()) {
return errors::InvalidArgument("empty identifier: ", msg);
}
// Require that the identifier starts with a nondigit, and is composed of
// nondigits and digits, as specified in section [2.11 Identifiers] of the
// C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
// defined as [0-9].
//
// Technically the standard also allows for `universal-character-name`, with a
// table of allowed unicode ranges, as well as `other implementation-defined
// characters`. We disallow those here to give better error messages, at the
  // expense of being more restrictive than the standard.
if (ident[0] != '_' && !IsAlpha(ident[0])) {
return errors::InvalidArgument("illegal leading char: ", msg);
}
for (size_t pos = 1; pos < ident.size(); ++pos) {
if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
return errors::InvalidArgument("illegal char: ", msg);
}
}
return Status::OK();
}
Status ValidateConfig(const Config& config) {
std::set<string> names;
for (const Feed& feed : config.feed()) {
TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
names.clear();
for (const Fetch& fetch : config.fetch()) {
TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
if (config.feed().empty() || config.fetch().empty()) {
return errors::InvalidArgument("feeds and fetches must be specified");
}
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_

View File

@ -0,0 +1,185 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
void ExpectErrorContains(Status status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
TEST(ValidateCppIdent, Simple) {
TF_EXPECT_OK(ValidateCppIdent("a", ""));
TF_EXPECT_OK(ValidateCppIdent("abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
// Make sure we didn't skip a valid letter or digit
string ident;
for (char c = 'a'; c <= 'z'; c++) {
ident.append(1, c);
}
for (char c = 'A'; c <= 'Z'; c++) {
ident.append(1, c);
}
for (char c = '0'; c <= '9'; c++) {
ident.append(1, c);
}
ident += "_";
TF_EXPECT_OK(ValidateCppIdent(ident, ""));
ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
}
TEST(ValidateConfig, Good) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(123);
feed->set_name("foo_debug");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->mutable_id()->set_output_index(0);
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->mutable_id()->set_output_index(456);
fetch->set_name("baz_debug");
fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("banana");
fetch->mutable_id()->set_output_index(0);
TF_EXPECT_OK(ValidateConfig(config));
}
TEST(ValidateConfig, BadEmpty) {
Config config;
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFeed) {
Config config;
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFetch) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadFeedNodeName) {
Config config;
config.add_feed();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFeedOutputIndex) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config),
                      "output_index must be non-negative");
}
TEST(ValidateConfig, BadFetchNodeName) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
config.add_fetch();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFetchOutputIndex) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config),
                      "output_index must be non-negative");
}
TEST(ValidateConfig, DuplicateFeedName) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("dup");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->set_name("dup");
ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
}
TEST(ValidateConfig, DuplicateFetchName) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("dup");
fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->set_name("dup");
ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
}
TEST(ValidateConfig, ConflictingFeedName) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("conflict");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->set_name("conflict_data");
ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
}
TEST(ValidateConfig, ConflictingFetchName) {
Config config;
Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("conflict");
fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->set_name("conflict_data");
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -0,0 +1,282 @@
licenses(["notice"]) # Apache 2.0
package_group(
name = "internal",
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
)
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/tf2xla:friends",
],
)
package(
default_visibility = [":internal"],
)
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
name = "jit",
visibility = [":friends"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
alwayslink = 1,
)
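# For example, a target inside the ":friends" package group could pull in both
# JIT devices simply by depending on this bundle (hypothetical target):
#
#   cc_binary(
#       name = "my_tf_binary",
#       deps = ["//tensorflow/compiler/jit"],
#   )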
cc_library(
name = "xla_cpu_jit",
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_local_launch_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
alwayslink = 1,
)
cc_library(
name = "xla_gpu_jit",
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_local_launch_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
],
alwayslink = 1,
)
cc_library(
name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"],
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
cc_library(
name = "xla_gpu_device",
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
# Internal targets below this point.
cc_library(
name = "common",
srcs = [
"defs.cc",
],
hdrs = [
"defs.h",
],
visibility = [":friends"],
)
cc_library(
name = "xla_device",
srcs = [
"xla_device.cc",
"xla_device_context.cc",
"xla_device_launch_op.cc",
"xla_device_ops.cc",
],
hdrs = [
"xla_device.h",
"xla_device_context.h",
"xla_device_launch_op.h",
"xla_device_ops.h",
],
deps = [
":common",
":jit_compilation_passes",
":xla_compilation_cache",
":xla_local_launch_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:assign_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:variable_ops",
],
alwayslink = 1,
)
cc_library(
name = "xla_compilation_cache",
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "jit_compilation_passes",
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
)
cc_library(
name = "compilation_passes",
srcs = [
"build_xla_launch_ops_pass.cc",
"encapsulate_subgraphs_pass.cc",
"graph_to_functiondef.cc",
"mark_for_compilation_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"encapsulate_subgraphs_pass.h",
"graph_to_functiondef.h",
"mark_for_compilation_pass.h",
],
deps = [
":common",
":parallel_check_op",
":xla_local_launch_op",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
"graph_to_functiondef_test.cc",
"mark_for_compilation_pass_test.cc",
],
deps = [
":compilation_passes",
":xla_local_launch_op",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "xla_local_launch_op",
srcs = ["xla_local_launch_op.cc"],
hdrs = ["xla_local_launch_op.h"],
deps = [
":common",
":xla_compilation_cache",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
],
alwayslink = 1,
)
tf_kernel_library(
name = "parallel_check_op",
srcs = ["parallel_check_op.cc"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,215 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
static Status BuildLaunchNode(
const string& nodename, const string& function_name,
const AttrValueMap& function_attr, const string& device_name,
const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
NodeDef def;
def.set_name(graph->NewName(nodename));
def.set_op("_XlaLaunch");
def.set_device(device_name);
AddNodeAttr("Tconstants", constant_dtypes, &def);
AddNodeAttr("Targs", arg_dtypes, &def);
AddNodeAttr("Tresults", result_dtypes, &def);
NameAttrList function;
function.set_name(function_name);
*function.mutable_attr() = function_attr;
AddNodeAttr("function", function, &def);
Status status;
*node = graph->AddNode(def, &status);
return status;
}
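// The NodeDef built above looks roughly like the following (a sketch; the
// device and attribute values are illustrative):
//   name: <derived from the replaced node's name via Graph::NewName>
//   op: "_XlaLaunch"
//   device: "/job:localhost/replica:0/task:0/gpu:0"
//   attr { key: "Tconstants" value { list { type: DT_INT32 } } }
//   attr { key: "Targs"      value { list { type: DT_FLOAT } } }
//   attr { key: "Tresults"   value { list { type: DT_FLOAT } } }
//   attr { key: "function"   value { func { name: "some_cluster" } } }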
static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
int num_constant_args;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
if (num_constant_args < 0 || num_constant_args > node->input_types().size()) {
return errors::InvalidArgument(
"Invalid number of constant arguments to XLA kernel");
}
DataTypeVector const_dtypes(node->input_types().begin(),
node->input_types().begin() + num_constant_args);
DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args,
node->input_types().end());
// Build a _XlaLaunch operator to execute the function body.
Node* launch_node;
TF_RETURN_IF_ERROR(
BuildLaunchNode(graph->NewName(node->name()), node->type_string(),
node->def().attr(), node->def().device(), const_dtypes,
arg_dtypes, node->output_types(), graph, &launch_node));
launch_node->set_assigned_device_name(node->assigned_device_name());
// Copy incoming edges to the launch node.
for (const Edge* edge : node->in_edges()) {
if (edge->IsControlEdge()) {
graph->AddControlEdge(edge->src(), launch_node);
} else {
graph->AddEdge(edge->src(), edge->src_output(), launch_node,
edge->dst_input());
}
}
// Copy outgoing edges to the launch node.
std::vector<const Edge*> out_edges(node->out_edges().begin(),
node->out_edges().end());
for (const Edge* edge : out_edges) {
Node* dst = edge->dst();
int src_output = edge->src_output();
int dst_input = edge->dst_input();
graph->RemoveEdge(edge);
if (edge->IsControlEdge()) {
graph->AddControlEdge(launch_node, dst);
} else {
graph->AddEdge(launch_node, src_output, dst, dst_input);
}
}
graph->RemoveNode(node);
return Status::OK();
}
Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
for (Node* n : graph->nodes()) {
// In all cases, only try to compile computational nodes.
if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
continue;
}
// Only compile nodes that are marked for compilation by the
// compilation-marking pass (via 'attr_name').
if (IsXlaCompiledKernel(*n)) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
}
}
return Status::OK();
}
namespace {
// Given a NodeDef 'ndef' and the function library runtime 'flr', if 'ndef' is
// a call to a compilable function defined in 'flr', returns OK and fills in
// 'kernel' with an XlaLaunchOp kernel that computes the node. Otherwise,
// returns a non-OK status.
//
// This routine is here so that FunctionLibraryRuntime can jit a
// specific function call as requested.
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
std::unique_ptr<OpKernel>* kernel) {
bool xla_compile = false;
if (!flr->GetFunctionLibraryDefinition()
->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
.ok() ||
!xla_compile) {
// Not marked as _XlaCompile=true.
return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
}
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterJitKernels();
if (!IsCompilable(flr, ndef)) {
// ndef is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
}
FunctionLibraryRuntime::Handle handle;
// If ndef is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody); // Can't be nullptr since we just instantiated it.
std::vector<bool> const_args(fbody->arg_types.size());
  // If we can't analyze the const args, bail out.
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
// There is a const arg. Bail out.
return errors::InvalidArgument("Const arg: ", i, " in ",
DebugString(fbody->fdef));
}
}
NodeDef launch_def;
launch_def.set_name(ndef.name());
launch_def.set_op("_XlaLaunch");
launch_def.set_device(flr->device()->name());
AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
AddNodeAttr("Targs", fbody->arg_types, &launch_def);
AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
NameAttrList func;
func.set_name(ndef.op());
*(func.mutable_attr()) = ndef.attr();
AddNodeAttr("function", func, &launch_def);
  // TODO(b/32387911): Handle the host memory types across function
  // calls properly. For now, we assume all inputs and outputs are in
  // device memory.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &launch_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
kernel->reset(new XlaLocalLaunchOp(&construction));
return s;
}
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
}
static bool register_me = RegisterLaunchOpCreator();
} // end namespace
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_

View File

@ -0,0 +1,22 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/defs.h"
namespace tensorflow {
const char* const kXlaCompileAttr = "_XlaCompile";
} // namespace tensorflow

View File

@ -0,0 +1,29 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides definitions needed for use of the TensorFlow XLA
// device.
#ifndef TENSORFLOW_COMPILER_JIT_DEFS_H_
#define TENSORFLOW_COMPILER_JIT_DEFS_H_
namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA
extern const char* const kXlaCompileAttr; // "_XlaCompile"
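// For example (a sketch mirroring how the JIT passes consume this attribute),
// a pass can test whether a node is marked for compilation with:
//   bool xla_compile = false;
//   if (GetNodeAttr(node_def, kXlaCompileAttr, &xla_compile).ok() &&
//       xla_compile) {
//     // ... compile the node with XLA ...
//   }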
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_

View File

@ -0,0 +1,660 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include <functional>
#include <numeric>
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
namespace {
// A node/slot pair.
// TODO(phawkins): is there a common definition of this?
struct NodeSlot {
NodeSlot() : node(nullptr), slot(-1) {}
NodeSlot(const Node* node, int slot) : node(node), slot(slot) {}
const Node* node;
int slot;
bool operator==(const NodeSlot& other) const {
return node == other.node && slot == other.slot;
}
struct Hasher {
uint64 operator()(NodeSlot const& s) const {
return Hash64Combine(std::hash<const Node*>()(s.node),
std::hash<int>()(s.slot));
}
};
struct PairHasher {
uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
}
};
};
class Encapsulator {
public:
Encapsulator(string group_attribute, Graph const* graph_in)
: group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
// Find subgraphs marked with 'group_attribute', and build a new
// subgraph, one for each value of 'group_attribute'.
Status SplitIntoSubgraphs();
  // Build a FunctionDef for each subgraph, and add it to 'library'. The
  // values of the 'group_attribute' annotations become the function names.
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
FunctionLibraryDefinition* library);
// Write a copy of the input graph to 'graph_out', where the subgraphs are
// replaced with calls to the new functions.
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
private:
// Returns the key attribute associated with a node. Returns the empty string
// if no key attribute is found.
string GetFunctionNameAttr(const Node* node) const;
// A subgraph of the input, all marked with a common 'group_attribute'
// value.
struct Subgraph {
// The subgraph extracted from the input graph, suitable for being turned
// into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
// returned by _Retval nodes.
std::unique_ptr<Graph> graph;
// Which device are these nodes on? Used both to check that all nodes
// are assigned to the same device, and to assign a device to the call node.
string device;
// NodeDef for the function call node.
NodeDef call_node_def;
// Function call node(s) in the output graph. Not owned.
// If parallel_checking is enabled, 'call_node_inputs' is the function call
// node to which inputs should be fed, and 'call_node_outputs' is the
// parallel check op from which outputs should be read. If parallel checking
// is disabled, both point to the function call node.
Node* call_node_inputs;
Node* call_node_outputs;
// Maps from source (producer node/slot) and destination
// (consumer node/slot) tensors in the input graph to _Arg numbers in
// the subgraph. The source map is one-to-one, whereas the dest map may be
// many-to-one.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src;
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst;
// The _Arg nodes in the subgraph, in order by argument number.
std::vector<Node*> args;
// Map from source tensor in the input graph to result #.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results;
};
// Builds a ParallelCheck op that compares the output of the original subgraph
// with the encapsulated subgraph.
Status BuildParallelCheckOp(
const std::unordered_map<const Node*, Node*>& node_images,
const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op);
const string group_attribute_;
const Graph* graph_in_;
std::unordered_map<string, Subgraph> subgraphs_;
TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
};
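// A sketch of how a caller typically drives this class (the concrete argument
// values are placeholders):
//   Encapsulator encapsulator(group_attribute, graph_in);
//   TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
//   TF_RETURN_IF_ERROR(
//       encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library));
//   TF_RETURN_IF_ERROR(
//       encapsulator.BuildOutputGraph(parallel_checking, graph_out));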
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
// Returns the function name attached to 'node', or the empty string if there is
// none.
string Encapsulator::GetFunctionNameAttr(Node const* node) const {
string attr;
if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
attr.clear();
}
return attr;
}
Status Encapsulator::SplitIntoSubgraphs() {
Status s;
// Map from input graph nodes to subgraph nodes.
std::unordered_map<Node*, Node*> node_images;
// Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
for (Node* node : graph_in_->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
string func_id = GetFunctionNameAttr(node);
if (func_id.empty()) continue;
Subgraph& subgraph = subgraphs_[func_id];
if (!subgraph.graph) {
subgraph.graph.reset(new Graph(graph_in_->op_registry()));
subgraph.graph->set_versions(graph_in_->versions());
}
Node* image = subgraph.graph->CopyNode(node);
image->ClearAttr(group_attribute_);
node_images[node] = image;
// Check the device matches any existing device.
string device = node->assigned_device_name().empty()
? node->def().device()
: node->assigned_device_name();
if (subgraph.device.empty()) {
subgraph.device = device;
} else if (subgraph.device != device) {
s.Update(errors::InvalidArgument(
"Mismatched devices for nodes to be grouped by Encapsulator"));
}
}
// Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for
// data edges that cross subgraph boundaries.
for (const Edge* edge : graph_in_->edges()) {
string src_func_id = GetFunctionNameAttr(edge->src());
string dst_func_id = GetFunctionNameAttr(edge->dst());
Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
// Copy edges that are local to a subgraph.
if (!src_func_id.empty() && src_func_id == dst_func_id) {
Graph* g = subgraphs_[src_func_id].graph.get();
if (edge->IsControlEdge()) {
g->AddControlEdge(src_image, dst_image);
} else {
g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
}
continue;
}
// Ignore cross-boundary control edges for right now. We will lift them
// onto the enclosing call operators in BuildOutputGraph().
if (edge->IsControlEdge()) continue;
// Add 'src' as an output of its subgraph, if applicable.
if (!src_func_id.empty()) {
Subgraph& src_subgraph = subgraphs_[src_func_id];
int ret_index = src_subgraph.results.size();
if (src_subgraph.results
.emplace(NodeSlot(edge->src(), edge->src_output()), ret_index)
.second) {
// Create a new _Retval node
DataType dtype = edge->src()->output_type(edge->src_output());
NodeDef ret_def;
ret_def.set_op(kRetValOp);
ret_def.set_name(src_subgraph.graph->NewName("output"));
AddNodeAttr("T", dtype, &ret_def);
AddNodeAttr("index", ret_index, &ret_def);
Node* ret = src_subgraph.graph->AddNode(ret_def, &s);
if (!s.ok()) return s;
// Add an edge from 'src' to _Retval.
src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0);
}
}
// Add 'dst' as an input of its subgraph, if applicable.
if (!dst_func_id.empty()) {
Subgraph& dst_subgraph = subgraphs_[dst_func_id];
// Create an _Arg node for this tensor, if none exists yet.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
bool inserted;
std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace(
NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size());
int arg_index = iter->second;
if (inserted) {
// This is the first time we have seen this tensor. Create an _Arg node.
DataType dtype = edge->dst()->input_type(edge->dst_input());
NodeDef arg_def;
NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp);
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
s = builder.Finalize(&arg_def);
if (!s.ok()) return s;
Node* arg = dst_subgraph.graph->AddNode(arg_def, &s);
if (!s.ok()) return s;
dst_subgraph.args.push_back(arg);
}
// Add an edge from the _Arg node to 'dst' in the subgraph.
dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
arg_index;
dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image,
edge->dst_input());
}
}
for (auto& entry : subgraphs_) {
FixupSourceAndSinkEdges(entry.second.graph.get());
}
return s;
}
Status Encapsulator::BuildFunctionDefs(
const RewriteSubgraphFn& rewrite_subgraph_fn,
FunctionLibraryDefinition* library) {
// For each subgraph, build a FunctionDef.
for (auto& subgraph_entry : subgraphs_) {
const string& name = subgraph_entry.first;
Subgraph& subgraph = subgraph_entry.second;
subgraph.call_node_def.set_op(name);
subgraph.call_node_def.set_name(name);
subgraph.call_node_def.set_device(subgraph.device);
if (rewrite_subgraph_fn) {
// Initialize the input and output permutations to the identity.
std::vector<int> input_permutation(subgraph.args_by_src.size());
std::iota(input_permutation.begin(), input_permutation.end(), 0);
std::vector<int> output_permutation(subgraph.results.size());
std::iota(output_permutation.begin(), output_permutation.end(), 0);
TF_RETURN_IF_ERROR(
rewrite_subgraph_fn(&subgraph.graph, &input_permutation,
&output_permutation, &subgraph.call_node_def));
// Apply the input/output permutations to the 'args_by_...' and 'results'
// mappings in 'subgraph', so when we build edges in BuildOutputGraph() we
// connect them to the right input/output positions.
if (input_permutation.size() != subgraph.args_by_src.size()) {
return errors::InvalidArgument("Input permutation has incorrect size.");
}
if (output_permutation.size() != subgraph.results.size()) {
return errors::InvalidArgument(
"Output permutation has incorrect size.");
}
for (auto& arg : subgraph.args_by_src) {
arg.second = input_permutation[arg.second];
}
for (auto& arg : subgraph.args_by_dst) {
arg.second = input_permutation[arg.second];
}
for (auto& result : subgraph.results) {
result.second = output_permutation[result.second];
}
}
FunctionDef fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef));
if (VLOG_IS_ON(1)) {
VLOG(2) << "Build function def " << name;
dump_graph::DumpGraphToFile(
strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph,
library);
dump_graph::DumpFunctionDefToFile(
strings::StrCat("encapsulate_fdef_", name), fdef);
}
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
}
return Status::OK();
}
Status Encapsulator::BuildParallelCheckOp(
const std::unordered_map<const Node*, Node*>& node_images,
const Encapsulator::Subgraph& subgraph, Graph* graph_out,
Node** parallel_check_op) {
// Build an index mapping output positions to node/slot pairs in the
// original graph.
std::vector<NodeSlot> results_by_num(subgraph.results.size());
for (const auto& entry : subgraph.results) {
results_by_num[entry.second] = entry.first;
}
// Build a parallel check NodeDef.
int num_results = results_by_num.size();
std::vector<DataType> result_dtypes(num_results);
std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
for (int i = 0; i < num_results; ++i) {
const NodeSlot& node_slot = results_by_num[i];
result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
expected_outputs[i] =
NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
node_slot.slot, result_dtypes[i]);
actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(),
i, result_dtypes[i]);
}
// Assign the parallel check op to a CPU on the same task as the cluster it is
// checking.
string device, dummy;
if (!DeviceNameUtils::SplitDeviceName(
subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) {
return errors::InvalidArgument("Could not parse device name");
}
strings::StrAppend(&device, "/cpu:0");
NodeDef check_def;
TF_RETURN_IF_ERROR(
NodeDefBuilder(graph_out->NewName(strings::StrCat(
subgraph.call_node_def.name(), "_parallel_check")),
"ParallelCheck")
.Device(device)
.Attr("T", result_dtypes)
.Input(expected_outputs)
.Input(actual_outputs)
.Finalize(&check_def));
Status s;
Node* check_op = graph_out->AddNode(check_def, &s);
if (!s.ok()) return s;
check_op->set_assigned_device_name(device);
// TODO(phawkins): it seems redundant to call AddEdge as well as
// pass Inputs to the NodeDefBuilder, but I have been unable to find a
// way to avoid it.
for (int i = 0; i < num_results; ++i) {
const NodeSlot& node_slot = results_by_num[i];
graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
i);
graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i);
}
*parallel_check_op = check_op;
return Status::OK();
}
Status Encapsulator::BuildOutputGraph(bool parallel_checking,
Graph* graph_out) {
Status s;
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;
// Copy all unmarked nodes to the output graph.
for (Node* node : graph_in_->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
string func_id = GetFunctionNameAttr(node);
// Don't copy nodes that are going to be encapsulated, unless parallel checking
// is enabled.
if (!func_id.empty() && !parallel_checking) continue;
Node* image = graph_out->CopyNode(node);
node_images[node] = image;
}
node_images[graph_in_->source_node()] = graph_out->source_node();
node_images[graph_in_->sink_node()] = graph_out->sink_node();
// Add function call nodes for each subgraph.
for (auto& subgraph_entry : subgraphs_) {
Subgraph& subgraph = subgraph_entry.second;
subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s);
if (!s.ok()) return s;
// Copy the assigned device and the key_annotation over.
subgraph.call_node_inputs->set_assigned_device_name(subgraph.device);
subgraph.call_node_outputs = subgraph.call_node_inputs;
if (parallel_checking) {
TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out,
&subgraph.call_node_outputs));
}
}
// Set of edges already added to the output graph, represented as (src, dst)
// pairs. We use the set to deduplicate edges; multiple edges in the input
// graph may map to one edge in the output graph.
std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
edges_added;
// Add edges to the graph_out graph.
for (const Edge* edge : graph_in_->edges()) {
string src_func_id = GetFunctionNameAttr(edge->src());
string dst_func_id = GetFunctionNameAttr(edge->dst());
// Ignore edges that are strictly contained within one subgraph, unless
// we are constructing parallel check graphs.
if (!src_func_id.empty() && src_func_id == dst_func_id) {
if (parallel_checking) {
Node* src_image = node_images.at(edge->src());
Node* dst_image = node_images.at(edge->dst());
if (edge->IsControlEdge()) {
graph_out->AddControlEdge(src_image, dst_image);
} else {
graph_out->AddEdge(src_image, edge->src_output(), dst_image,
edge->dst_input());
}
}
continue;
}
// We have an edge that crosses a cluster boundary.
Node* src_image = src_func_id.empty()
? node_images.at(edge->src())
: subgraphs_.at(src_func_id).call_node_outputs;
Node* dst_image = dst_func_id.empty()
? node_images.at(edge->dst())
: subgraphs_.at(dst_func_id).call_node_inputs;
// Copy control edges. Lift control edges onto the enclosing call operator.
if (edge->IsControlEdge()) {
// Add the control edge, if we have not already added it.
if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
.second) {
graph_out->AddControlEdge(src_image, dst_image);
}
// If parallel checking is enabled, also add a control edge to the
// corresponding parallel check op.
if (parallel_checking) {
graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
}
continue;
}
int src_output = edge->src_output();
if (!src_func_id.empty()) {
// 'src' is in a subgraph. Use the corresponding call output instead.
const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
src_output =
src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output()));
}
int dst_input = edge->dst_input();
if (!dst_func_id.empty()) {
// 'dst' is in a subgraph. Use the corresponding call input instead.
const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
dst_input =
dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
// If we are parallel checking, also feed the tensor as an input to the
// corresponding parallel check subgraph.
if (parallel_checking) {
graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
edge->dst_input());
}
}
// Add the edge, if we have not already added it.
if (edges_added
.emplace(NodeSlot(src_image, src_output),
NodeSlot(dst_image, dst_input))
.second) {
graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
}
}
return s;
}
} // anonymous namespace
Status EncapsulateSubgraphsInFunctions(
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
Status s;
Encapsulator encapsulator(std::move(group_attribute), &graph_in);
s = encapsulator.SplitIntoSubgraphs();
if (!s.ok()) return s;
s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library);
if (!s.ok()) return s;
std::unique_ptr<Graph> out(new Graph(library));
out->set_versions(graph_in.versions());
s = encapsulator.BuildOutputGraph(parallel_checking, out.get());
if (!s.ok()) return s;
*graph_out = std::move(out);
return s;
}
// Renumber the indices of _Arg nodes in a graph, according to
// 'permutation' that maps old indices to new indices.
static Status RenumberArguments(Graph* graph,
const std::vector<int>& permutation) {
for (Node* n : graph->nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
if (index < 0 || index >= permutation.size()) {
return errors::InvalidArgument("Invalid argument number");
}
n->AddAttr("index", permutation[index]);
}
}
return Status::OK();
}
Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
legacy_flags::EncapsulateSubgraphsPassFlags* flags =
legacy_flags::GetEncapsulateSubgraphsPassFlags();
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
options.flib_def);
}
std::unique_ptr<Graph> graph_out;
FunctionLibraryDefinition* const library = options.flib_def;
OptimizerOptions opts;
std::unique_ptr<FunctionLibraryRuntime> flr(
NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
TF_GRAPH_DEF_VERSION, library, opts));
auto rewrite_subgraph = [&flr](
std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node) {
// Optimize the subgraph.
Graph* g = subgraph->release();
OptimizeGraph(flr.get(), &g);
subgraph->reset(g);
std::vector<bool> const_args(input_permutation->size());
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args));
// Compute a permutation of the arguments such that the constant arguments
// are first.
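// For example (illustrative only): const_args = {true, false, true} yields
// input_permutation = {0, 2, 1}, so the two constant arguments keep positions
// 0 and 1 and the non-constant argument moves to position 2.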
const int num_consts =
std::count(const_args.begin(), const_args.end(), true);
int const_pos = 0;
int arg_pos = num_consts;
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
(*input_permutation)[i] = const_pos;
++const_pos;
} else {
(*input_permutation)[i] = arg_pos;
++arg_pos;
}
}
// Renumber argument nodes in the graph.
TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
// TODO(phawkins): add a forward is-constant analysis, similarly split
// outputs into host-memory constants and device-memory non-constants.
AddNodeAttr(kXlaCompiledKernelAttr, true, node);
AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
return Status::OK();
};
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
kXlaClusterAttr, **options.graph, rewrite_subgraph,
flags->tf_xla_parallel_checking, &graph_out, library));
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
options.flib_def);
}
*options.graph = std::move(graph_out);
return Status::OK();
}
bool IsXlaCompiledKernel(const Node& node) {
bool is_compiled = false;
bool has_compilation_attr =
GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
is_compiled;
return has_compilation_attr ? is_compiled : false;
}
} // namespace tensorflow

View File

@@ -0,0 +1,86 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An optimization pass that groups nodes marked with a common
// kXlaClusterAttr into functions, and replaces the original nodes by
// calls. The calls are annotated with kXlaCompiledKernelAttr.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// A rewriting function to apply to each subgraph during encapsulation.
// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
// 'input_permutation' is a mapping from old argument numbers to new argument
// numbers, whereas 'output_permutation' is the same for outputs. Both
// 'input_permutation' and 'output_permutation' are initialized to the identity
// permutation. 'nodedef' is the NodeDef for the call to the function under
// construction, provided to allow additional attributes to be set.
typedef std::function<Status(
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node_def)>
RewriteSubgraphFn;
// Transformation that finds subgraphs whose nodes are marked with
// 'group_attribute', splits those subgraphs into functions, and replaces
// the originals with function calls.
//
// 'group_attribute' must be a string valued-attribute that names the new
// functions to introduce.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
// If 'parallel_checking' is true, the unencapsulated operators are added to the
// output graph, together with a "ParallelCheck" operator that verifies that
// the original and encapsulated subgraphs produce similar results.
//
// TODO(phawkins): currently, some information in control edges
// is not preserved. Suppose you have A and B in the main
// graph, C and D in a subgraph. B and C have control deps from A, D has control
// dep from B. Originally D must run after C, post-transformation this
// dependency is lost.
Status EncapsulateSubgraphsInFunctions(
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
extern const char* const kXlaCompiledKernelAttr;
// Does `node` have the kXlaCompiledKernelAttr attribute?
bool IsXlaCompiledKernel(const Node& node);
// Functions produced by the EncapsulateSubgraphs pass have their arguments
// ordered such that compile-time constant arguments are first in the argument
// order. The functions are annotated with the following attribute giving the
// number of constant arguments.
extern const char* const kXlaNumConstantArgsAttr;
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
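To make the interface above concrete, here is a minimal usage sketch (not part of this commit). EncapsulateClusters, the "_cluster" group attribute and the "_illustrative_tag" attribute are hypothetical names; only EncapsulateSubgraphsInFunctions, RewriteSubgraphFn and AddNodeAttr come from the code in this change.

#include <memory>
#include <vector>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/core/framework/node_def_util.h"

// Hypothetical wrapper: encapsulate every "_cluster"-tagged subgraph and mark
// the generated call nodes with an example attribute.
tensorflow::Status EncapsulateClusters(
    const tensorflow::Graph& graph_in,
    tensorflow::FunctionLibraryDefinition* library,
    std::unique_ptr<tensorflow::Graph>* graph_out) {
  tensorflow::RewriteSubgraphFn tag_call_node =
      [](std::unique_ptr<tensorflow::Graph>* subgraph,
         std::vector<int>* input_permutation,
         std::vector<int>* output_permutation,
         tensorflow::NodeDef* node_def) {
        // Keep the identity argument/result order; just tag the call node.
        tensorflow::AddNodeAttr("_illustrative_tag", true, node_def);
        return tensorflow::Status::OK();
      };
  return tensorflow::EncapsulateSubgraphsInFunctions(
      "_cluster", graph_in, tag_call_node,
      /*parallel_checking=*/false, graph_out, library);
}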

View File

@@ -0,0 +1,397 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
// TODO(phawkins) use a more sophisticated equality test.
if (a.DebugString() != b.DebugString()) {
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ", expected:\n",
a.DebugString());
}
return false;
}
return true;
}
bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
const FunctionDefLibrary& actual, string* diff) {
std::unordered_map<string, const FunctionDef*> actual_index;
for (const FunctionDef& function : actual.function()) {
actual_index[function.signature().name()] = &function;
}
for (const FunctionDef& expected_function : expected.function()) {
auto it = actual_index.find(expected_function.signature().name());
if (it == actual_index.end()) {
if (diff) {
*diff = strings::StrCat("Did not find expected function '",
expected_function.signature().name(), "'");
}
return false;
}
if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
actual_index.erase(it);
}
if (!actual_index.empty()) {
if (diff != nullptr) {
*diff = strings::StrCat("Found unexpected function '",
actual_index.begin()->second->signature().name(),
"'");
}
return false;
}
return true;
}
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \
do { \
string diff; \
EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \
<< diff << "\nActual: " << actual.DebugString(); \
} while (false)
REGISTER_OP("InputTest").Output("o: float");
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
REGISTER_OP("BinaryTest")
.Input("a: float")
.Input("b: float")
.Output("o: float");
REGISTER_OP("AddNLikeTest")
.Input("inputs: N * T")
.Output("sum: T")
.Attr("N: int >= 1")
.Attr("T: numbertype")
.SetIsCommutative()
.SetIsAggregate();
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTest", opts);
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", a, opts);
}
Node* Binary(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest", a, b, opts);
}
Node* AddNLike(std::vector<ops::NodeOut> inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
opts.op_registry());
node_builder.Input(inputs);
return opts.FinalizeBuilder(&node_builder);
}
Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
return ops::SourceOp("_Arg",
opts.WithAttr("T", type).WithAttr("index", index));
}
Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
opts.op_registry());
node_builder.Input(a).Attr("index", index);
return opts.FinalizeBuilder(&node_builder);
}
Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
Status s;
// Convert the GraphDef to a Graph
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), *library));
GraphConstructorOptions options;
options.allow_internal_ops = true;
std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
if (!s.ok()) return s;
std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
/* rewrite_subgraph_fn= */ {},
/* parallel_checking= */ false,
&graph_out, lib_def.get());
if (!s.ok()) return s;
GraphDef graphdef_out;
graph_out->ToGraphDef(&graphdef_out);
graphdef->Swap(&graphdef_out);
*library = lib_def->ToProto();
return s;
}
// If there are no marked nodes, funcification should be a no-op.
TEST(EncapsulateSubgraphsTest, NoFunctions) {
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = Input(builder.opts().WithName("A"));
Node* b = Input(builder.opts().WithName("B"));
Node* c = Unary(a, builder.opts().WithName("C"));
Binary(b, c, builder.opts().WithName("D"));
GraphDef graphdef_in;
FunctionDefLibrary library_in;
builder.ToGraphDef(&graphdef_in);
*library_in.add_function() = test::function::XTimesTwo();
GraphDef graphdef_out = graphdef_in;
FunctionDefLibrary library_out = library_in;
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
}
// Test with one function to transform.
TEST(EncapsulateSubgraphsTest, OneFunction) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Give nodes 'c' and 'd' names that collide after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Binary(a, d, b1.opts().WithName("E"));
b1.ToGraphDef(&graphdef);
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
{
{{"C"}, "UnaryTest", {"input__0"}},
{{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}},
},
{{"output__2", "c:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Binary(a, call, b2.opts().WithName("E"));
b2.ToGraphDef(&graphdef_expected);
}
// Compare the rewritten graph and library against the expected forms.
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform.
TEST(EncapsulateSubgraphsTest, TwoFunctions) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* control = Input(b1.opts().WithName("Control"));
Node* c =
Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
"_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
"_encapsulate", "F2"));
Binary(a, d, b1.opts().WithName("E"));
b1.ToGraphDef(&graphdef);
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"input__0:float"}, {"output__1:float"}, {},
{
{{"C"}, "UnaryTest", {"input__0"}},
},
{{"output__1", "C:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
{
{{"D"}, "BinaryTest", {"input__0", "input__1"}},
},
{{"output__2", "D:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* control = Input(b2.opts().WithName("Control"));
NodeBuilder nb("F1", "F1", lib_def.get());
nb.Input(a).ControlInput(control);
Node* call1 = b2.opts().FinalizeBuilder(&nb);
NodeBuilder nb2("F2", "F2", lib_def.get());
nb2.Input(b).Input(call1).ControlInput(control);
Node* call2 = b2.opts().FinalizeBuilder(&nb2);
Binary(a, call2, b2.opts().WithName("E"));
b2.ToGraphDef(&graphdef_expected);
}
// Compare the rewritten graph and library against the expected forms.
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Returns a vector of node names in 'graph', sorted by name.
std::vector<string> GraphNodes(const Graph& graph) {
std::vector<string> nodes;
for (const auto& node : graph.nodes()) {
if (!node->IsSource() && !node->IsSink()) {
nodes.push_back(node->name());
}
}
std::sort(nodes.begin(), nodes.end());
return nodes;
}
// Returns a sorted vector of (src, dst) edges in 'graph'.
std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
std::vector<std::pair<string, string>> edges;
for (const Edge* edge : graph.edges()) {
if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
edges.emplace_back(
strings::StrCat(edge->src()->name(), ":", edge->src_output()),
strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
}
std::sort(edges.begin(), edges.end());
return edges;
}
TEST(EncapsulateSubgraphsTest, InputDeduplication) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
auto add1 = ops::Add(root.WithOpName("add1"), x, x);
add1.node()->AddAttr("_cluster", "cluster1");
auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
add2.node()->AddAttr("_cluster", "cluster2");
auto out = ops::Mul(root.WithOpName("mul"), add1, add2);
Graph graph_before_encapsulation(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false, &graph, &library));
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
std::vector<std::pair<string, string>> expected_edges = {
{"cluster1:0", "cluster2:0"},
{"cluster1:0", "mul:0"},
{"cluster2:0", "mul:1"},
{"x:0", "cluster1:0"}};
EXPECT_EQ(expected_edges, GraphEdges(*graph));
}
TEST(EncapsulateSubgraphsTest, ParallelChecking) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
auto add1 = ops::Add(root.WithOpName("add1"), x1, x2);
add1.node()->AddAttr("_cluster", "cluster1");
auto add2 = ops::Add(root.WithOpName("add2"), add1, x2);
add2.node()->AddAttr("_cluster", "cluster1");
auto out = ops::Mul(root.WithOpName("mul"), x1, add2);
Graph graph_before_encapsulation(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/true, &graph, &library));
std::vector<string> expected_nodes = {
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",
"mul", "x1", "x2"};
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
std::vector<std::pair<string, string>> expected_edges = {
{"add1:0", "add2:0"},
{"add2:0", "cluster1_parallel_check/_0:0"},
{"cluster1:0", "cluster1_parallel_check/_0:1"},
{"cluster1_parallel_check/_0:0", "mul:1"},
{"x1:0", "add1:0"},
{"x1:0", "cluster1:0"},
{"x1:0", "mul:0"},
{"x2:0", "add1:1"},
{"x2:0", "add2:1"},
{"x2:0", "cluster1:1"},
};
EXPECT_EQ(expected_edges, GraphEdges(*graph));
}
} // namespace
} // namespace tensorflow

View File

@@ -0,0 +1,274 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
const char* const kArgOp = "_Arg";
const char* const kRetValOp = "_Retval";
// Class that maintains a one-to-one original node name -> new name mapping.
// We have to normalize the names used as input and output arguments to
// match regexp "[a-z][a-z0-9_]*". Once we rename them, we risk creating
// a name collision with the other node names, so if necessary we add
// a suffix to make names unique. So if we have an input named "A" and a
// node in the function body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
public:
NodeNameMapping() = default;
// Normalize the input/output name and then make it unique.
string Normalize(const string& name);
// Make the node name unique.
string Uniquify(const string& name);
// Look up how a node name was previously normalized/uniquified.
// Returns empty if name was never seen.
string Renormalize(const string& name) const;
private:
string NormalizeHelper(string name) const;
string UniquifyHelper(string name);
std::unordered_set<string> used_names_;
std::unordered_map<string, string> name_mapping_;
};
string NodeNameMapping::NormalizeHelper(string name) const {
// Convert letters to lowercase and non-alphanumeric characters to '_'.
if (name.empty()) name = "unknown";
const int n = name.size();
for (int i = 0; i < n; i++) {
char c = name[i];
if (isalnum(c)) {
if (isupper(c)) {
name[i] = tolower(c);
}
} else {
name[i] = '_';
}
}
return name;
}
string NodeNameMapping::UniquifyHelper(string name) {
// If the name hasn't been used yet, use it as-is.
if (used_names_.insert(name).second) return name;
// Add a suffix to name to make it unique.
for (int i = 0;; ++i) {
const string candidate = strings::StrCat(name, "_", i);
if (used_names_.insert(candidate).second) return candidate;
}
}
string NodeNameMapping::Normalize(const string& name) {
const string normalized = UniquifyHelper(NormalizeHelper(name));
name_mapping_[name] = normalized;
return normalized;
}
string NodeNameMapping::Uniquify(const string& name) {
const string uniqued = UniquifyHelper(name);
name_mapping_[name] = uniqued;
return uniqued;
}
string NodeNameMapping::Renormalize(const string& name) const {
const auto iter = name_mapping_.find(name);
if (iter == name_mapping_.end()) return string();
return iter->second;
}
} // anonymous namespace
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& graph, const string& name,
FunctionDef* fdef) {
fdef->mutable_signature()->set_name(name);
std::unordered_map<string, string> tensor_renaming;
std::unordered_map<string, string> return_values;
NodeNameMapping node_names;
for (Node const* node : graph.nodes()) {
if (!node->IsOp()) continue;
if (node->type_string() == kArgOp) {
int index;
DataType type;
GetNodeAttr(node->def(), "T", &type);
GetNodeAttr(node->def(), "index", &index);
while (fdef->signature().input_arg_size() <= index) {
fdef->mutable_signature()->add_input_arg();
}
OpDef::ArgDef* argdef =
fdef->mutable_signature()->mutable_input_arg(index);
argdef->set_type(type);
const string normalized = node_names.Normalize(node->name());
argdef->set_name(normalized);
tensor_renaming[strings::StrCat(node->name(), ":0")] = normalized;
continue;
}
if (node->type_string() == kRetValOp) {
int index;
DataType type;
GetNodeAttr(node->def(), "T", &type);
GetNodeAttr(node->def(), "index", &index);
while (fdef->signature().output_arg_size() <= index) {
fdef->mutable_signature()->add_output_arg();
}
OpDef::ArgDef* argdef =
fdef->mutable_signature()->mutable_output_arg(index);
argdef->set_type(type);
const string normalized = node_names.Normalize(node->name());
argdef->set_name(normalized);
CHECK_EQ(node->in_edges().size(), 1);
Edge const* edge = *node->in_edges().begin();
return_values[normalized] =
strings::StrCat(edge->src()->name(), ":", edge->src_output());
continue;
}
NodeDef* node_def = fdef->add_node_def();
node_def->CopyFrom(node->def());
node_def->set_name(node_names.Uniquify(node->name()));
node_def->clear_device();
// Reset input names based on graph rather than the NodeDef.
node_def->clear_input();
// Edges, indexed by dst_input.
std::vector<const Edge*> in_edges;
std::vector<const Edge*> control_edges;
for (Edge const* edge : node->in_edges()) {
if (edge->src()->IsSource()) continue;
if (edge->IsControlEdge()) {
control_edges.push_back(edge);
} else {
if (in_edges.size() <= edge->dst_input()) {
in_edges.resize(edge->dst_input() + 1);
}
in_edges[edge->dst_input()] = edge;
}
}
// Add regular inputs
for (int i = 0; i < in_edges.size(); ++i) {
const Edge* edge = in_edges[i];
if (edge == nullptr) {
return errors::InvalidArgument(
"Nonconsecutive input edges; missing "
"input edge ",
i, " for node ", node->name());
}
node_def->add_input(
strings::StrCat(edge->src()->name(), ":", edge->src_output()));
}
// Add control inputs
for (const Edge* edge : control_edges) {
node_def->add_input(strings::StrCat("^", edge->src()->name()));
}
// Populate tensor_renaming.
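// For example, a node named "D" whose op declares a single output "z" yields
// the entry tensor_renaming["D:0"] = "D:z:0", i.e. the "name:output:index"
// form that FunctionDef node inputs use.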
NameRangeMap output_ranges;
TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
&output_ranges));
for (const auto& output : output_ranges) {
for (int i = output.second.first; i < output.second.second; ++i) {
const string tensor_name = strings::StrCat(
node_def->name(), ":", output.first, ":", i - output.second.first);
tensor_renaming[strings::StrCat(node->name(), ":", i)] = tensor_name;
}
}
}
// Detect missing function inputs.
for (int i = 0; i < fdef->signature().input_arg_size(); ++i) {
const string& input_name = fdef->signature().input_arg(i).name();
if (input_name.empty()) {
return errors::InvalidArgument("Missing input ", i, " to function ",
name);
}
}
// Remap input names. We do this as a second pass to allow the nodes to be in
// any order.
for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
NodeDef* node_def = fdef->mutable_node_def(n_index);
for (int i = 0; i < node_def->input_size(); ++i) {
if (StringPiece(node_def->input(i)).starts_with("^")) {
// Control input
const string normalized =
node_names.Renormalize(node_def->input(i).substr(1));
if (normalized.empty()) {
return errors::InvalidArgument(
"Could not remap control input ", i, ", '", node_def->input(i),
"', of node '", node_def->name(), "' in function ", name);
}
*node_def->mutable_input(i) = strings::StrCat("^", normalized);
} else {
const auto iter = tensor_renaming.find(node_def->input(i));
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument(
"Could not remap input ", i, ", '", node_def->input(i),
"', of node '", node_def->name(), "' in function ", name);
}
*node_def->mutable_input(i) = iter->second;
}
}
}
// Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string& ret_name = fdef->signature().output_arg(r).name();
if (ret_name.empty()) {
return errors::InvalidArgument("Missing output ", r, " to function ",
name);
}
const string& return_value = return_values[ret_name];
const auto iter = tensor_renaming.find(return_value);
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument("Could not remap return value ", r, ", '",
ret_name, "', of '", return_value,
"' in function ", name);
}
(*fdef->mutable_ret())[ret_name] = iter->second;
}
return Status::OK();
}
} // namespace tensorflow
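As an aside, the rename rule behind NodeNameMapping::Normalize above can be summarized by a small standalone sketch (illustrative only, not part of this commit; NormalizeAndUniquify is a hypothetical helper, not the class used by GraphToFunctionDef):

#include <cctype>
#include <string>
#include <unordered_set>

// Lowercase, map non-alphanumerics to '_', then append a numeric suffix on
// collision -- the behavior described in the NodeNameMapping comment.
std::string NormalizeAndUniquify(std::string name,
                                 std::unordered_set<std::string>* used) {
  if (name.empty()) name = "unknown";
  for (char& c : name) {
    if (std::isalnum(static_cast<unsigned char>(c))) {
      c = std::tolower(static_cast<unsigned char>(c));
    } else {
      c = '_';
    }
  }
  if (used->insert(name).second) return name;
  for (int i = 0;; ++i) {
    const std::string candidate = name + "_" + std::to_string(i);
    if (used->insert(candidate).second) return candidate;
  }
}
// With an empty set, "A" maps to "a" and a subsequent "a" maps to "a_0",
// matching the example in the class comment.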

View File

@@ -0,0 +1,33 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
#define TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Converts 'graph' to a FunctionDef 'fdef', with name 'name'.
// Closely modeled on the Python code in
// third_party/tensorflow/python/framework/function.py
Status GraphToFunctionDef(const Graph& graph, const string& name,
FunctionDef* fdef);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_

View File

@@ -0,0 +1,87 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
// TODO(phawkins) use a more sophisticated equality test.
if (a.DebugString() != b.DebugString()) {
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ":\n", a.DebugString(),
"\n ---- vs. ----\n", b.DebugString());
}
return false;
}
return true;
}
TEST(GraphToFunctionDefTest, Basics) {
Scope root = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0);
auto b = ops::_Arg(root.WithOpName("B"), DT_FLOAT, 1);
auto c = ops::_Arg(root.WithOpName("C"), DT_FLOAT, 2);
auto d = ops::Add(root.WithOpName("D"), a, b);
auto e = ops::Add(root.WithOpName("b"), d, c);
auto f = ops::Neg(root.WithOpName("h"), e);
auto g =
ops::AddN(root.WithOpName("G"), std::initializer_list<ops::Output>{e, f});
auto h = ops::_Retval(root.WithOpName("H"), g, 0);
GraphDef graph_def;
root.ToGraphDef(&graph_def);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphConstructorOptions options;
TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get()));
FunctionDef fdef;
TF_EXPECT_OK(GraphToFunctionDef(*graph, "test_fn", &fdef));
FunctionDef fdef_expected = FunctionDefHelper::Create(
"test_fn", // function name
{"a: float", "b: float", "c: float"}, // inputs
{"h_0: float"}, // outputs
{}, // attrs
{
// nodes in the function body
{{"D"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
{{"b_0"}, "Add", {"D:z:0", "c"}, {{"T", DT_FLOAT}}},
{{"h"}, "Neg", {"b_0:z:0"}, {{"T", DT_FLOAT}}},
{{"G"}, "AddN", {"b_0:z:0", "h:y:0"}, {{"N", 2}, {"T", DT_FLOAT}}},
},
{{"h_0", "G:sum:0"}}); // return values
string diff;
bool fdefs_equal = EqualFunctionDef(fdef_expected, fdef, &diff);
EXPECT_TRUE(fdefs_equal) << diff;
}
} // namespace
} // namespace tensorflow

View File

@@ -0,0 +1,41 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
)
cc_library(
name = "graphcycles",
srcs = ["graphcycles.cc"],
hdrs = ["graphcycles.h"],
deps = [
"//tensorflow/core:lib",
],
)
cc_test(
name = "graphcycles_test",
srcs = ["graphcycles_test.cc"],
deps = [
":graphcycles",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@@ -0,0 +1,391 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// GraphCycles provides incremental cycle detection on a dynamic
// graph using the following algorithm:
//
// A dynamic topological sort algorithm for directed acyclic graphs
// David J. Pearce, Paul H. J. Kelly
// Journal of Experimental Algorithmics (JEA), Volume 11, 2006, Article No. 1.7
//
// Brief summary of the algorithm:
//
// (1) Maintain a rank for each node that is consistent
// with the topological sort of the graph. I.e., path from x to y
// implies rank[x] < rank[y].
// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
// (3) Otherwise: adjust ranks in the neighborhood of x and y.
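//
// Worked example (illustrative): with nodes a, b, c at ranks 0, 1, 2 and edges
// a->b and b->c, inserting c->a has rank[c] = 2 >= rank[a] = 0, so a forward
// search from a bounded by rank 2 runs; it reaches c, whose rank equals the
// bound, which proves the edge would close a cycle, and the insertion is
// rejected. When no such path exists, the nodes found by the forward and
// backward searches are simply re-ranked so the invariant holds again.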
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include <algorithm>
#include <unordered_set>
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
typedef std::unordered_set<int32> NodeSet;
template <typename T>
struct VecStruct {
typedef gtl::InlinedVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
Node() : in(4), out(4) {} // Small hashtables for in/out edges
int32 rank; // rank number assigned by Pearce-Kelly algorithm
bool visited; // Temporary marker used by depth-first-search
void* data; // User-supplied data
NodeSet in; // List of immediate predecessor nodes in graph
NodeSet out; // List of immediate successor nodes in graph
};
} // namespace
struct GraphCycles::Rep {
Vec<Node*> nodes_;
Vec<int32> free_nodes_; // Indices for unused entries in nodes_
// Temporary state.
Vec<int32> deltaf_; // Results of forward DFS
Vec<int32> deltab_; // Results of backward DFS
Vec<int32> list_; // All nodes to reprocess
Vec<int32> merged_; // Rank values to assign to list_ entries
Vec<int32> stack_; // Emulates recursion stack when doing depth first search
};
GraphCycles::GraphCycles() : rep_(new Rep) {}
GraphCycles::~GraphCycles() {
for (int i = 0; i < rep_->nodes_.size(); i++) {
delete rep_->nodes_[i];
}
delete rep_;
}
bool GraphCycles::CheckInvariants() const {
Rep* r = rep_;
NodeSet ranks; // Set of ranks seen so far.
for (int32 x = 0; x < r->nodes_.size(); x++) {
Node* nx = r->nodes_[x];
if (nx->visited) {
LOG(FATAL) << "Did not clear visited marker on node " << x;
}
if (!ranks.insert(nx->rank).second) {
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
}
for (auto y : nx->out) {
Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) {
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
<< nx->rank << "->" << ny->rank;
}
}
}
return true;
}
int32 GraphCycles::NewNode() {
if (rep_->free_nodes_.empty()) {
Node* n = new Node;
n->visited = false;
n->data = NULL;
n->rank = rep_->nodes_.size();
rep_->nodes_.push_back(n);
return n->rank;
} else {
// Preserve preceding rank since the set of ranks in use must be
// a permutation of [0,rep_->nodes_.size()-1].
int32 r = rep_->free_nodes_.back();
rep_->nodes_[r]->data = NULL;
rep_->free_nodes_.pop_back();
return r;
}
}
void GraphCycles::RemoveNode(int32 node) {
Node* x = rep_->nodes_[node];
for (auto y : x->out) {
rep_->nodes_[y]->in.erase(node);
}
for (auto y : x->in) {
rep_->nodes_[y]->out.erase(node);
}
x->in.clear();
x->out.clear();
rep_->free_nodes_.push_back(node);
}
void* GraphCycles::GetNodeData(int32 node) const {
return rep_->nodes_[node]->data;
}
void GraphCycles::SetNodeData(int32 node, void* data) {
rep_->nodes_[node]->data = data;
}
bool GraphCycles::HasEdge(int32 x, int32 y) const {
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
}
void GraphCycles::RemoveEdge(int32 x, int32 y) {
rep_->nodes_[x]->out.erase(y);
rep_->nodes_[y]->in.erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound);
static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound);
static void Reorder(GraphCycles::Rep* r);
static void Sort(const Vec<Node*>&, Vec<int32>* delta);
static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst);
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes);
bool GraphCycles::InsertEdge(int32 x, int32 y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes_[x];
if (!nx->out.insert(y).second) {
// Edge already exists.
return true;
}
Node* ny = r->nodes_[y];
ny->in.insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
return true;
}
// Current rank assignments are incompatible with the new edge. Recompute.
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (!ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.erase(y);
ny->in.erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_);
return false;
}
BackwardDFS(r, x, ny->rank);
Reorder(r);
return true;
}
static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
// Avoid recursion since stack space might be limited.
// We instead keep a stack of nodes to visit.
r->deltaf_.clear();
r->stack_.clear();
r->stack_.push_back(n);
while (!r->stack_.empty()) {
n = r->stack_.back();
r->stack_.pop_back();
Node* nn = r->nodes_[n];
if (nn->visited) continue;
nn->visited = true;
r->deltaf_.push_back(n);
for (auto w : nn->out) {
Node* nw = r->nodes_[w];
if (nw->rank == upper_bound) {
return false; // Cycle
}
if (!nw->visited && nw->rank < upper_bound) {
r->stack_.push_back(w);
}
}
}
return true;
}
static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
r->deltab_.clear();
r->stack_.clear();
r->stack_.push_back(n);
while (!r->stack_.empty()) {
n = r->stack_.back();
r->stack_.pop_back();
Node* nn = r->nodes_[n];
if (nn->visited) continue;
nn->visited = true;
r->deltab_.push_back(n);
for (auto w : nn->in) {
Node* nw = r->nodes_[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack_.push_back(w);
}
}
}
}
static void Reorder(GraphCycles::Rep* r) {
Sort(r->nodes_, &r->deltab_);
Sort(r->nodes_, &r->deltaf_);
// Adds contents of delta lists to list_ (backwards deltas first).
r->list_.clear();
MoveToList(r, &r->deltab_, &r->list_);
MoveToList(r, &r->deltaf_, &r->list_);
// Produce sorted list of all ranks that will be reassigned.
r->merged_.resize(r->deltab_.size() + r->deltaf_.size());
std::merge(r->deltab_.begin(), r->deltab_.end(), r->deltaf_.begin(),
r->deltaf_.end(), r->merged_.begin());
// Assign the ranks in order to the collected list.
for (int32 i = 0; i < r->list_.size(); i++) {
r->nodes_[r->list_[i]]->rank = r->merged_[i];
}
}
static void Sort(const Vec<Node*>& nodes, Vec<int32>* delta) {
struct ByRank {
const Vec<Node*>* nodes;
bool operator()(int32 a, int32 b) const {
return (*nodes)[a]->rank < (*nodes)[b]->rank;
}
};
ByRank cmp;
cmp.nodes = &nodes;
std::sort(delta->begin(), delta->end(), cmp);
}
static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) {
for (int32 i = 0; i < src->size(); i++) {
int32 w = (*src)[i];
(*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank
r->nodes_[w]->visited = false; // Prepare for future DFS calls
dst->push_back(w);
}
}
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes) {
for (int32 i = 0; i < nodes.size(); i++) {
r->nodes_[nodes[i]]->visited = false;
}
}
int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
int32 path[]) const {
// Forward depth first search starting at x until we hit y.
// As we descend into a node, we push it onto the path.
// As we leave a node, we remove it from the path.
int path_len = 0;
Rep* r = rep_;
NodeSet seen;
r->stack_.clear();
r->stack_.push_back(x);
while (!r->stack_.empty()) {
int32 n = r->stack_.back();
r->stack_.pop_back();
if (n < 0) {
// Marker to indicate that we are leaving a node
path_len--;
continue;
}
if (path_len < max_path_len) {
path[path_len] = n;
}
path_len++;
r->stack_.push_back(-1); // Will remove tentative path entry
if (n == y) {
return path_len;
}
for (auto w : r->nodes_[n]->out) {
if (seen.insert(w).second) {
r->stack_.push_back(w);
}
}
}
return 0;
}
bool GraphCycles::IsReachable(int32 x, int32 y) const {
return FindPath(x, y, 0, NULL) > 0;
}
bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {
if (x == y) return true;
Rep* r = rep_;
Node* nx = r->nodes_[x];
Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) {
// x cannot reach y since it is after it in the topological ordering
return false;
}
// See if x can reach y using a DFS search that is limited to y's rank
bool reachable = !ForwardDFS(r, x, ny->rank);
// Clear any visited markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_);
return reachable;
}
bool GraphCycles::ContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachableNonConst(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return false;
}
Node* nb = rep_->nodes_[b];
std::unordered_set<int32> out = std::move(nb->out);
std::unordered_set<int32> in = std::move(nb->in);
for (auto y : out) {
rep_->nodes_[y]->in.erase(b);
}
for (auto y : in) {
rep_->nodes_[y]->out.erase(b);
}
rep_->free_nodes_.push_back(b);
for (auto y : out) {
InsertEdge(a, y);
}
for (auto y : in) {
InsertEdge(y, a);
}
return true;
}
std::unordered_set<int32> GraphCycles::Successors(int32 node) {
return rep_->nodes_[node]->out;
}
} // namespace tensorflow
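A minimal usage sketch of the interface implemented above (illustrative only, not part of this commit; CycleCheckSketch is a hypothetical function):

#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"

void CycleCheckSketch() {
  tensorflow::GraphCycles cycles;
  tensorflow::int32 a = cycles.NewNode();
  tensorflow::int32 b = cycles.NewNode();
  tensorflow::int32 c = cycles.NewNode();
  bool ok_ab = cycles.InsertEdge(a, b);  // true
  bool ok_bc = cycles.InsertEdge(b, c);  // true
  // Rejected: this edge would close the cycle a->b->c->a.
  bool ok_ca = cycles.InsertEdge(c, a);  // false
  // FindPath fills 'path' with up to max_path_len node ids and returns the
  // path length; here it returns 3 for the path a -> b -> c.
  tensorflow::int32 path[4];
  int len = cycles.FindPath(a, c, 4, path);
  (void)ok_ab; (void)ok_bc; (void)ok_ca; (void)len;
}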

View File

@@ -0,0 +1,128 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
// Nodes are identified by small integers. It is not possible to
// record multiple edges with the same (source, destination) pair;
// requests to add an edge where one already exists are silently
// ignored.
//
// It is also not possible to introduce a cycle; an attempt to insert
// an edge that would introduce a cycle fails and returns false.
//
// GraphCycles uses no internal locking; calls into it should be
// serialized externally.
// Performance considerations:
// Works well on sparse graphs, poorly on dense graphs.
// Extra information is maintained incrementally to detect cycles quickly.
// InsertEdge() is very fast when the edge already exists, and reasonably fast
// otherwise.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
#include <unordered_set>
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// NOTE!!!
// For now a copy of this is forked to net/plaque. If you
// find a bug or add a feature, please inform the owners of the
// net/plaque copy in case it should be integrated.
// NOTE!!!
class GraphCycles {
public:
GraphCycles();
~GraphCycles();
// Allocate an unused node id and return it.
// The new node has a null pointer for its node data.
// All node identifiers passed to other routines in this interface
// must have been allocated by NewNode() and not yet deallocated
// by RemoveNode().
int32 NewNode();
// Remove "node" from the graph, deleting all edges to and from it.
// After this call the identifier "node" may no longer be used
// as an argument to any routine until it has been reallocated with
// NewNode().
void RemoveNode(int32 node);
// Attempt to insert an edge from source_node to dest_node. If the
// edge would introduce a cycle, return false without making any
// changes. Otherwise add the edge and return true.
bool InsertEdge(int32 source_node, int32 dest_node);
// Remove any edge that exists from source_node to dest_node.
void RemoveEdge(int32 source_node, int32 dest_node);
// Return whether there is an edge directly from source_node to dest_node.
bool HasEdge(int32 source_node, int32 dest_node) const;
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
// removed from the graph, and edges to/from 'b' are replaced with edges
// to/from 'a'. If contracting the edge would create a cycle, does nothing
// and returns false.
bool ContractEdge(int32 a, int32 b);
// Return whether dest_node is reachable from source_node
// by following edges.
bool IsReachable(int32 source_node, int32 dest_node) const;
// A faster non-thread-safe version of IsReachable.
bool IsReachableNonConst(int32 source_node, int32 dest_node);
// Return or set the node data for a node. This data is unused
// by the implementation.
void *GetNodeData(int32 node) const;
void SetNodeData(int32 node, void *data);
// Find a path from "source" to "dest". If such a path exists, place the
// node IDs of the nodes on the path in the array path[], and return the
// number of nodes on the path. If the path is longer than max_path_len
// nodes, only the first max_path_len nodes are placed in path[]. The client
// should compare the return value with "max_path_len" to see when this
// occurs. If no path exists, return 0. Any valid path stored in path[]
// will start with "source" and end with "dest". There is no guarantee that
// the path is the shortest, but no node will appear twice in the path,
// except the source and destination node if they are identical; therefore,
// the return value is at most one greater than the number of nodes in the
// graph.
int FindPath(int32 source, int32 dest, int max_path_len, int32 path[]) const;
// Check internal invariants. Crashes on failure, returns true on success.
// Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node);
// ----------------------------------------------------
struct Rep;
private:
Rep *rep_; // opaque representation
TF_DISALLOW_COPY_AND_ASSIGN(GraphCycles);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
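To illustrate ContractEdge in particular (a sketch, not part of this commit; ContractEdgeSketch is a hypothetical function): contraction is refused exactly when the two endpoints remain connected by another path, since the merged node would then lie on a cycle.

#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"

void ContractEdgeSketch() {
  tensorflow::GraphCycles cycles;
  tensorflow::int32 a = cycles.NewNode();
  tensorflow::int32 b = cycles.NewNode();
  tensorflow::int32 c = cycles.NewNode();
  cycles.InsertEdge(a, b);
  cycles.InsertEdge(b, c);
  cycles.InsertEdge(a, c);
  // Refused: with the direct edge a->c removed, c is still reachable from a
  // through b, so merging a and c would create a cycle.
  bool merged_ac = cycles.ContractEdge(a, c);  // false
  // Accepted: b is merged into a and b's remaining edges are moved onto a.
  bool merged_ab = cycles.ContractEdge(a, b);  // true
  (void)merged_ac; (void)merged_ab;
}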

View File

@@ -0,0 +1,515 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A test for the GraphCycles interface.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include <random>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
using tensorflow::string;
// We emulate a GraphCycles object with a node vector and an edge vector.
// We then compare the two implementations.
typedef std::vector<int> Nodes;
struct Edge {
int from;
int to;
};
typedef std::vector<Edge> Edges;
// Return whether "to" is reachable from "from".
static bool IsReachable(Edges *edges, int from, int to,
std::unordered_set<int> *seen) {
seen->insert(from); // we are investigating "from"; don't do it again
if (from == to) return true;
for (int i = 0; i != edges->size(); i++) {
Edge *edge = &(*edges)[i];
if (edge->from == from) {
if (edge->to == to) { // success via edge directly
return true;
} else if (seen->find(edge->to) == seen->end() && // success via edge
IsReachable(edges, edge->to, to, seen)) {
return true;
}
}
}
return false;
}
static void PrintNodes(Nodes *nodes) {
LOG(INFO) << "NODES (" << nodes->size() << ")";
for (int i = 0; i != nodes->size(); i++) {
LOG(INFO) << (*nodes)[i];
}
}
static void PrintEdges(Edges *edges) {
LOG(INFO) << "EDGES (" << edges->size() << ")";
for (int i = 0; i != edges->size(); i++) {
int a = (*edges)[i].from;
int b = (*edges)[i].to;
LOG(INFO) << a << " " << b;
}
LOG(INFO) << "---";
}
static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) {
LOG(INFO) << "GC EDGES";
for (int i = 0; i != nodes->size(); i++) {
for (int j = 0; j != nodes->size(); j++) {
int a = (*nodes)[i];
int b = (*nodes)[j];
if (gc->HasEdge(a, b)) {
LOG(INFO) << a << " " << b;
}
}
}
LOG(INFO) << "---";
}
static void PrintTransitiveClosure(Nodes *nodes, Edges *edges,
tensorflow::GraphCycles *gc) {
LOG(INFO) << "Transitive closure";
for (int i = 0; i != nodes->size(); i++) {
for (int j = 0; j != nodes->size(); j++) {
int a = (*nodes)[i];
int b = (*nodes)[j];
std::unordered_set<int> seen;
if (IsReachable(edges, a, b, &seen)) {
LOG(INFO) << a << " " << b;
}
}
}
LOG(INFO) << "---";
}
static void PrintGCTransitiveClosure(Nodes *nodes,
tensorflow::GraphCycles *gc) {
LOG(INFO) << "GC Transitive closure";
for (int i = 0; i != nodes->size(); i++) {
for (int j = 0; j != nodes->size(); j++) {
int a = (*nodes)[i];
int b = (*nodes)[j];
if (gc->IsReachable(a, b)) {
LOG(INFO) << a << " " << b;
}
}
}
LOG(INFO) << "---";
}
static void CheckTransitiveClosure(Nodes *nodes, Edges *edges,
tensorflow::GraphCycles *gc) {
std::unordered_set<int> seen;
for (int i = 0; i != nodes->size(); i++) {
for (int j = 0; j != nodes->size(); j++) {
seen.clear();
int a = (*nodes)[i];
int b = (*nodes)[j];
bool gc_reachable = gc->IsReachable(a, b);
CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b));
bool reachable = IsReachable(edges, a, b, &seen);
if (gc_reachable != reachable) {
PrintEdges(edges);
PrintGCEdges(nodes, gc);
PrintTransitiveClosure(nodes, edges, gc);
PrintGCTransitiveClosure(nodes, gc);
LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable "
<< reachable << " a " << a << " b " << b;
}
}
}
}
static void CheckEdges(Nodes *nodes, Edges *edges,
tensorflow::GraphCycles *gc) {
int count = 0;
for (int i = 0; i != edges->size(); i++) {
int a = (*edges)[i].from;
int b = (*edges)[i].to;
if (!gc->HasEdge(a, b)) {
PrintEdges(edges);
PrintGCEdges(nodes, gc);
LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")";
}
}
for (int i = 0; i != nodes->size(); i++) {
for (int j = 0; j != nodes->size(); j++) {
int a = (*nodes)[i];
int b = (*nodes)[j];
if (gc->HasEdge(a, b)) {
count++;
}
}
}
if (count != edges->size()) {
PrintEdges(edges);
PrintGCEdges(nodes, gc);
LOG(FATAL) << "edges->size() " << edges->size() << " count " << count;
}
}
// Returns the index of a randomly chosen node in *nodes.
// Requires *nodes be non-empty.
static int RandomNode(std::mt19937 *rnd, Nodes *nodes) {
std::uniform_int_distribution<int> distribution(0, nodes->size() - 1);
return distribution(*rnd);
}
// Returns the index of a randomly chosen edge in *edges.
// Requires *edges be non-empty.
static int RandomEdge(std::mt19937 *rnd, Edges *edges) {
std::uniform_int_distribution<int> distribution(0, edges->size() - 1);
return distribution(*rnd);
}
// Returns the index of edge (from, to) in *edges or -1 if it is not in *edges.
static int EdgeIndex(Edges *edges, int from, int to) {
int i = 0;
while (i != edges->size() &&
((*edges)[i].from != from || (*edges)[i].to != to)) {
i++;
}
return i == edges->size() ? -1 : i;
}
TEST(GraphCycles, RandomizedTest) {
Nodes nodes;
Edges edges; // from, to
tensorflow::GraphCycles graph_cycles;
static const int kMaxNodes = 7; // use <= 7 nodes to keep test short
static const int kDataOffset = 17; // an offset to the node-specific data
int n = 100000;
int op = 0;
std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1);
for (int iter = 0; iter != n; iter++) {
if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n;
if (VLOG_IS_ON(3)) {
LOG(INFO) << "===============";
LOG(INFO) << "last op " << op;
PrintNodes(&nodes);
PrintEdges(&edges);
PrintGCEdges(&nodes, &graph_cycles);
}
for (int i = 0; i != nodes.size(); i++) {
ASSERT_EQ(reinterpret_cast<intptr_t>(graph_cycles.GetNodeData(i)),
i + kDataOffset)
<< " node " << i;
}
CheckEdges(&nodes, &edges, &graph_cycles);
CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
std::uniform_int_distribution<int> distribution(0, 5);
op = distribution(rnd);
switch (op) {
case 0: // Add a node
if (nodes.size() < kMaxNodes) {
int new_node = graph_cycles.NewNode();
ASSERT_NE(-1, new_node);
VLOG(1) << "adding node " << new_node;
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
graph_cycles.SetNodeData(
new_node, reinterpret_cast<void *>(
static_cast<intptr_t>(new_node + kDataOffset)));
ASSERT_GE(new_node, 0);
for (int i = 0; i != nodes.size(); i++) {
ASSERT_NE(nodes[i], new_node);
}
nodes.push_back(new_node);
}
break;
case 1: // Remove a node
if (nodes.size() > 0) {
int node_index = RandomNode(&rnd, &nodes);
int node = nodes[node_index];
nodes[node_index] = nodes.back();
nodes.pop_back();
VLOG(1) << "removing node " << node;
graph_cycles.RemoveNode(node);
int i = 0;
while (i != edges.size()) {
if (edges[i].from == node || edges[i].to == node) {
edges[i] = edges.back();
edges.pop_back();
} else {
i++;
}
}
}
break;
case 2: // Add an edge
if (nodes.size() > 0) {
int from = RandomNode(&rnd, &nodes);
int to = RandomNode(&rnd, &nodes);
if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
if (graph_cycles.InsertEdge(nodes[from], nodes[to])) {
Edge new_edge;
new_edge.from = nodes[from];
new_edge.to = nodes[to];
edges.push_back(new_edge);
} else {
std::unordered_set<int> seen;
ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen))
<< "Edge " << nodes[to] << "->" << nodes[from];
}
}
}
break;
case 3: // Remove an edge
if (edges.size() > 0) {
int i = RandomEdge(&rnd, &edges);
int from = edges[i].from;
int to = edges[i].to;
ASSERT_EQ(i, EdgeIndex(&edges, from, to));
edges[i] = edges.back();
edges.pop_back();
ASSERT_EQ(-1, EdgeIndex(&edges, from, to));
VLOG(1) << "removing edge " << from << " " << to;
graph_cycles.RemoveEdge(from, to);
}
break;
case 4: // Check a path
if (nodes.size() > 0) {
int from = RandomNode(&rnd, &nodes);
int to = RandomNode(&rnd, &nodes);
int32 path[2 * kMaxNodes];
int path_len = graph_cycles.FindPath(nodes[from], nodes[to],
2 * kMaxNodes, path);
std::unordered_set<int> seen;
bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen);
bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]);
ASSERT_EQ(gc_reachable,
graph_cycles.IsReachableNonConst(nodes[from], nodes[to]));
ASSERT_EQ(path_len != 0, reachable);
ASSERT_EQ(path_len != 0, gc_reachable);
// In the following line, we add one because a node can appear
// twice, if the path is from that node to itself, perhaps via
// every other node.
ASSERT_LE(path_len, kMaxNodes + 1);
if (path_len != 0) {
ASSERT_EQ(nodes[from], path[0]);
ASSERT_EQ(nodes[to], path[path_len - 1]);
for (int i = 1; i < path_len; i++) {
ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i]));
ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i]));
}
}
}
break;
case 5: // Check invariants
CHECK(graph_cycles.CheckInvariants());
break;
default:
LOG(FATAL);
}
// Very rarely, test graph expansion by adding then removing many nodes.
std::bernoulli_distribution rarely(1.0 / 1024.0);
if (rarely(rnd)) {
VLOG(3) << "Graph expansion";
CheckEdges(&nodes, &edges, &graph_cycles);
CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
for (int i = 0; i != 256; i++) {
int new_node = graph_cycles.NewNode();
ASSERT_NE(-1, new_node);
VLOG(1) << "adding node " << new_node;
ASSERT_GE(new_node, 0);
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
graph_cycles.SetNodeData(
new_node, reinterpret_cast<void *>(
static_cast<intptr_t>(new_node + kDataOffset)));
for (int j = 0; j != nodes.size(); j++) {
ASSERT_NE(nodes[j], new_node);
}
nodes.push_back(new_node);
}
for (int i = 0; i != 256; i++) {
ASSERT_GT(nodes.size(), 0);
int node_index = RandomNode(&rnd, &nodes);
int node = nodes[node_index];
nodes[node_index] = nodes.back();
nodes.pop_back();
VLOG(1) << "removing node " << node;
graph_cycles.RemoveNode(node);
int j = 0;
while (j != edges.size()) {
if (edges[j].from == node || edges[j].to == node) {
edges[j] = edges.back();
edges.pop_back();
} else {
j++;
}
}
}
CHECK(graph_cycles.CheckInvariants());
}
}
}
class GraphCyclesTest : public ::testing::Test {
public:
tensorflow::GraphCycles g_;
// Test relies on ith NewNode() call returning Node numbered i
GraphCyclesTest() {
for (int i = 0; i < 100; i++) {
CHECK_EQ(i, g_.NewNode());
}
CHECK(g_.CheckInvariants());
}
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
void AddMultiples() {
// For every node x > 0: add edge to 2*x, 3*x
for (int x = 1; x < 25; x++) {
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
}
CHECK(g_.CheckInvariants());
}
string Path(int x, int y) {
static const int kPathSize = 5;
int32 path[kPathSize];
int np = g_.FindPath(x, y, kPathSize, path);
string result;
for (int i = 0; i < np; i++) {
if (i >= kPathSize) {
result += " ...";
break;
}
if (!result.empty()) result.push_back(' ');
char buf[20];
snprintf(buf, sizeof(buf), "%d", path[i]);
result += buf;
}
return result;
}
};
TEST_F(GraphCyclesTest, NoCycle) {
AddMultiples();
CHECK(g_.CheckInvariants());
}
TEST_F(GraphCyclesTest, SimpleCycle) {
AddMultiples();
EXPECT_FALSE(AddEdge(8, 4));
EXPECT_EQ("4 8", Path(4, 8));
CHECK(g_.CheckInvariants());
}
TEST_F(GraphCyclesTest, IndirectCycle) {
AddMultiples();
EXPECT_TRUE(AddEdge(16, 9));
CHECK(g_.CheckInvariants());
EXPECT_FALSE(AddEdge(9, 2));
EXPECT_EQ("2 4 8 16 9", Path(2, 9));
CHECK(g_.CheckInvariants());
}
TEST_F(GraphCyclesTest, LongPath) {
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(4, 6));
ASSERT_TRUE(AddEdge(6, 8));
ASSERT_TRUE(AddEdge(8, 10));
ASSERT_TRUE(AddEdge(10, 12));
ASSERT_FALSE(AddEdge(12, 2));
EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12));
CHECK(g_.CheckInvariants());
}
TEST_F(GraphCyclesTest, RemoveNode) {
ASSERT_TRUE(AddEdge(1, 2));
ASSERT_TRUE(AddEdge(2, 3));
ASSERT_TRUE(AddEdge(3, 4));
ASSERT_TRUE(AddEdge(4, 5));
g_.RemoveNode(3);
ASSERT_TRUE(AddEdge(5, 1));
}
TEST_F(GraphCyclesTest, ManyEdges) {
const int N = 50;
for (int i = 0; i < N; i++) {
for (int j = 1; j < N; j++) {
ASSERT_TRUE(AddEdge(i, i + j));
}
}
CHECK(g_.CheckInvariants());
ASSERT_TRUE(AddEdge(2 * N - 1, 0));
CHECK(g_.CheckInvariants());
ASSERT_FALSE(AddEdge(10, 9));
CHECK(g_.CheckInvariants());
}
TEST_F(GraphCyclesTest, ContractEdge) {
ASSERT_TRUE(AddEdge(1, 2));
ASSERT_TRUE(AddEdge(1, 3));
ASSERT_TRUE(AddEdge(2, 3));
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
EXPECT_FALSE(g_.ContractEdge(1, 3));
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.ContractEdge(1, 2));
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.HasEdge(1, 4));
EXPECT_TRUE(g_.HasEdge(3, 4));
EXPECT_TRUE(g_.ContractEdge(1, 3));
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 4));
}
static void BM_StressTest(int iters, int num_nodes) {
while (iters > 0) {
tensorflow::GraphCycles g;
int32 *nodes = new int32[num_nodes];
for (int i = 0; i < num_nodes; i++) {
nodes[i] = g.NewNode();
}
for (int i = 0; i < num_nodes && iters > 0; i++, iters--) {
int end = std::min(num_nodes, i + 5);
for (int j = i + 1; j < end; j++) {
if (nodes[i] >= 0 && nodes[j] >= 0) {
CHECK(g.InsertEdge(nodes[i], nodes[j]));
}
}
}
delete[] nodes;
}
}
BENCHMARK(BM_StressTest)->Range(2048, 1048576);

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
BuildXlaLaunchOpsPass);
} // namespace tensorflow
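For illustration only, a hedged sketch of how an additional pass would slot into the ordering above. SomeLaterPass is a hypothetical GraphOptimizationPass subclass, not part of this change; the phase number 40 simply places it after BuildXlaLaunchOpsPass within the same POST_REWRITE_FOR_EXEC grouping.

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Hypothetical pass used only to illustrate phase ordering.
class SomeLaterPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    return Status::OK();  // no-op placeholder
  }
};

// 40 > 30, so this would run after BuildXlaLaunchOpsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
                      SomeLaterPass);

}  // namespace tensorflow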

View File

@ -0,0 +1,67 @@
# Legacy command line flags for the XLA bridge libraries.
# Please do not add more flags to this package.
# The XLA bridge libraries were written in an environment that allowed
# command-line flags to be scattered freely throughout the libraries. This
# model, while initially convenient, leads to a proliferation of unused
# command-line flags in tests and binaries, and to serious problems in servers,
# where one might wish parameters to be different in independent RPC calls to
# the same routine.
#
# Please don't add more flags. If you're a library author, pass options and
# parameters explicitly through the library's interface.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
cc_library(
name = "encapsulate_subgraphs_pass_flags",
srcs = ["encapsulate_subgraphs_pass_flags.cc"],
hdrs = ["encapsulate_subgraphs_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "mark_for_compilation_pass_flags",
srcs = ["mark_for_compilation_pass_flags.cc"],
hdrs = ["mark_for_compilation_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "parallel_check_op_flags",
srcs = ["parallel_check_op_flags.cc"],
hdrs = ["parallel_check_op_flags.h"],
deps =
[
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
#include <mutex>
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static EncapsulateSubgraphsPassFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new EncapsulateSubgraphsPassFlags;
flags->tf_xla_parallel_checking = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking,
"Debug tool. Runs both JIT-compiled and interpreted graphs in "
"parallel and verifies they produce the same outputs."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's
// encapsulate_subgraphs_pass module.
void AppendEncapsulateSubgraphsPassFlags(std::vector<Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with the XLA bridge's
// encapsulate_subgraphs_pass module.
void AppendEncapsulateSubgraphsPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with the XLA bridge's
// encapsulate_subgraphs_pass module.
typedef struct {
bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and
// interpreted graphs in parallel and verifies
// they produce the same outputs.
} EncapsulateSubgraphsPassFlags;
// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
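As a rough usage sketch (not part of the change): a binary that wants these flags on its command line appends them before parsing and reads them afterwards. Only the Append/Get functions are taken from this header; the Flags::Parse call is assumed to have its usual (argc, argv, flag_list) signature from tensorflow/core/util/command_line_flags.h.

#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char** argv) {
  std::vector<tensorflow::Flag> flag_list;
  tensorflow::legacy_flags::AppendEncapsulateSubgraphsPassFlags(&flag_list);
  tensorflow::Flags::Parse(&argc, argv, flag_list);  // assumed signature
  const tensorflow::legacy_flags::EncapsulateSubgraphsPassFlags* flags =
      tensorflow::legacy_flags::GetEncapsulateSubgraphsPassFlags();
  return flags->tf_xla_parallel_checking ? 1 : 0;
}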

View File

@ -0,0 +1,76 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
#include <mutex>
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static MarkForCompilationPassFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new MarkForCompilationPassFlags;
flags->tf_xla_auto_jit = 0;
flags->tf_xla_min_cluster_size = 2;
flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
flags->tf_xla_clustering_debug = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
"operators placed on an XLA device or operators explicitly marked "
"for compilation."),
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's
// mark_for_compilation_pass module.
void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the MarkForCompilationPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with the XLA bridge's
// mark_for_compilation_pass module.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with the XLA bridge's
// mark_for_compilation_pass module.
typedef struct {
int32 tf_xla_auto_jit; // Control compilation of operators into XLA
// computations on CPU and GPU devices. 0 = use
// ConfigProto setting; -1 = off; 1 = on for things
// very likely to be improved; 2 = on for everything.
// Experimental.
int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA
// compilation. Ignored for operators placed
// on an XLA device or operators explicitly
// marked for compilation.
int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA
// compilation.
bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
} MarkForCompilationPassFlags;
// Return a pointer to the MarkForCompilationPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
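A minimal sketch of how client code might read the struct above; only GetMarkForCompilationPassFlags and the field semantics documented in this header are assumed.

#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"

// Returns true when auto-clustering is forced on by the flag rather than
// deferred to the ConfigProto setting (-1 = off, 0 = defer, 1/2 = on).
bool AutoJitForcedOn() {
  tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
      tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
  return flags->tf_xla_auto_jit >= 1;
}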

View File

@ -0,0 +1,68 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for the XLA bridge's parallel_check_op module.
#include <mutex>
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static ParallelCheckOpFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new ParallelCheckOpFlags;
flags->parallel_check_failfast = true;
flags->parallel_check_atol = "1e-5";
flags->parallel_check_rtol = "1e-5";
flag_list = new std::vector<Flag>({
Flag("parallel_check_failfast", &flags->parallel_check_failfast,
"Fail immediately on first parallel-check comparison error."),
Flag("parallel_check_atol", &flags->parallel_check_atol,
"Absolute error tolerance for parallel-check comparison."),
Flag("parallel_check_rtol", &flags->parallel_check_rtol,
"Relative error tolerance for parallel-check comparison."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's
// parallel_check_op module.
void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the ParallelCheckOpFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
ParallelCheckOpFlags* GetParallelCheckOpFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -0,0 +1,52 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
// Legacy flags for the XLA bridge's parallel_check_op module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with the XLA bridge's
// parallel_check_op module.
void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with the XLA bridge's
// parallel_check_op module.
typedef struct {
bool parallel_check_failfast; // Fail immediately on first parallel-check
// comparison error.
string parallel_check_atol; // Absolute error tolerance for parallel-check
// comparison.
string parallel_check_rtol; // Relative error tolerance for parallel-check
// comparison.
} ParallelCheckOpFlags;
// Return a pointer to the ParallelCheckOpFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
ParallelCheckOpFlags* GetParallelCheckOpFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_

View File

@ -0,0 +1,534 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
const char* const kXlaClusterAttr = "_XlaCluster";
namespace {
bool HasXLAKernel(const NodeDef& node_def, DeviceType jit_device_type) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node_def.op() == "SymbolicGradient") return false;
return FindKernelDef(jit_device_type, node_def, nullptr, nullptr).ok();
}
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 5;
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_def' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Loop marking: " << while_def.op();
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_def, "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
VLOG(2) << "Can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_def, "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
VLOG(2) << "Can't compile loop body: " << body_func;
return false;
}
VLOG(2) << "Loop is compilable.";
return true;
}
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Function marking: " << call_def.op();
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Function depth limit exceeded";
return false;
}
FunctionLibraryRuntime::Handle handle;
Status status =
lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
if (!status.ok()) {
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
for (Node* node : fbody->graph->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
if (node->def().op() == "While") {
// Handle functional While loop (not in open source build).
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
lib_runtime);
}
if (!HasXLAKernel(node->def(), jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
lib_runtime)) {
VLOG(2) << "Function marking failed: unsupported op " << node->name()
<< ": " << node->def().ShortDebugString();
return false;
}
}
VLOG(2) << "Function is compilable: " << call_def.op();
return true;
}
// Returns the DeviceType corresponding to 'device'.
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
std::unordered_set<Node*>* candidates) {
OptimizerOptions opts;
std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
for (Node* node : graph.nodes()) {
if (node->IsSource() || node->IsSink()) continue;
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
const string* jit_device_name;
CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name,
/*requires_jit=*/nullptr));
DeviceType jit_device_type(*jit_device_name);
if (!HasXLAKernel(node->def(), jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
<< ": " << node->def().op();
continue;
}
if (node->def().op() == "While" &&
!IsCompilableWhile(node->def(), jit_device_type, 0,
lib_runtime.get())) {
continue;
}
candidates->insert(node);
}
return Status::OK();
}
// Union-Find data structure used to compute clusters. We use our own
// implementation because we want one key feature: when merging clusters, we
// need to know which value becomes the representative of the merged clusters.
// We use the representatives to name nodes in a cycle detection graph, and we
// need to control which node is named.
// TODO(phawkins): consider merging this code with union-find implementations
// in Tensorflow, e.g., in SimplePlacer.
class Cluster {
public:
Cluster();
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's representative becomes
// the representative of the merged cluster; the representative of 'other'
// is ignored.
void Merge(Cluster* other);
// Each cluster has an associated integer 'representative', initialized to -1
// by default.
int GetRepresentative() { return FindRoot()->representative_; }
void SetRepresentative(int representative) {
FindRoot()->representative_ = representative;
}
private:
// Finds the root element of the cluster. Performs path compression.
Cluster* FindRoot();
int representative_;
int rank_;
int size_; // Size of the cluster.
Cluster* parent_;
};
Cluster::Cluster()
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
void Cluster::Merge(Cluster* other) {
Cluster* a = FindRoot();
Cluster* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->representative_ = a->representative_;
b->size_ += a->size_;
}
Cluster* Cluster::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
Device* device = flr->device();
const string* jit_device_name;
CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name,
/*requires_jit=*/nullptr));
DeviceType jit_device_type(*jit_device_name);
return IsCompilableCall(ndef, jit_device_type, 0, flr);
}
Status MarkForCompilationPass::Run(
const GraphOptimizationPassOptions& options) {
// TODO(phawkins): precompute the "GetJitDevice" properties of each device
// ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level =
options.session_options->config.graph_options()
.optimizer_options()
.global_jit_level();
if (global_jit_level == OptimizerOptions::DEFAULT) {
// To set compilation to be on by default, change the following line.
global_jit_level = OptimizerOptions::OFF;
}
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_auto_jit == -1 ||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
// the setting in ConfigProto.
global_jit_level =
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
}
const FunctionLibraryDefinition* fld = options.flib_def;
auto is_compilable = [global_jit_level, fld](const Node* node,
const DeviceType& device_type) {
const string* jit_device;
bool requires_jit;
if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device,
&requires_jit)) {
return false;
}
// If this device requires a JIT, we must say yes.
if (requires_jit) return true;
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
// Otherwise use the value of global_jit_level.
return global_jit_level > 0;
};
return RunImpl(options, is_compilable);
}
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
node.type_string() == "Size";
}
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
Status MarkForCompilationPass::RunImpl(
const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn) {
VLOG(1) << "MarkForCompilationPass::Run";
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterJitKernels();
Graph* graph = options.graph->get();
std::unordered_set<Node*> compilation_candidates;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
: Env::Default(),
is_compilable_fn, &compilation_candidates));
GraphCycles cycles;
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0.
CHECK_EQ(i, cycles.NewNode());
}
// Compute the loop structure of the graph.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
// The clustering code must avoid adding cycles to the graph to prevent
// deadlock. However, the graph may contain loops, which would trigger the
// cycle detection code. To handle loops, we alter the structure of the cycle
// detection graph, disconnecting each loop from the enclosing graph.
// Specifically, we:
// * add a new "frame" node for each loop.
// * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
// to/from the corresponding frame node. In essence, we collapse the loop
// into a single node for the purpose of cycle detection in the enclosing
// graph.
// * the body of the loop should now be disconnected from the rest of the
// graph; we make it acyclic by breaking loop backedges (edges outgoing from
// "NextIteration" nodes).
// Map from frame name strings to node IDs in the cycle detection graph.
std::unordered_map<string, int> frame_nodes;
// Get the cycle graph node ID for frame 'frame_name', or add one if none
// exists.
auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
if (frame_id < 0) {
// The emplace succeeded; we have not allocated a frame node yet.
frame_id = cycles.NewNode();
}
return frame_id;
};
for (Edge const* edge : graph->edges()) {
if (edge->dst()->IsEnter()) {
// Lift edges to an "Enter" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->dst()->id()].frame_name;
if (!cycles.InsertEdge(edge->src()->id(),
GetOrAddFrameNodeId(frame_name))) {
return errors::Internal("Cycle detected when adding enter->frame edge");
}
continue;
}
if (edge->src()->IsExit()) {
// Lift edges from an "Exit" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->src()->id()].frame_name;
if (!cycles.InsertEdge(GetOrAddFrameNodeId(frame_name),
edge->dst()->id())) {
return errors::Internal("Cycle detected when adding frame->exit edge");
}
// Drop the original edge.
continue;
}
if (edge->src()->IsNextIteration()) {
// Break loop back-edges.
continue;
}
if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
// This should never happen. All cycles in the graph should contain
// a control flow operator.
return errors::Internal(
"Found cycle in graph without control flow operator during XLA "
"compilation.");
}
}
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<Cluster> clusters(graph->num_node_ids());
std::deque<Cluster*> worklist;
for (Node* node : compilation_candidates) {
clusters[node->id()].SetRepresentative(node->id());
worklist.push_back(&clusters[node->id()]);
}
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
while (!worklist.empty()) {
int from = worklist.front()->GetRepresentative();
worklist.pop_front();
Node* node_from = graph->FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal("Found control flow node in clustering worklist");
}
for (int to : cycles.Successors(from)) {
if (to >= graph->num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
Node* node_to = graph->FindNodeId(to);
if (compilation_candidates.find(node_to) == compilation_candidates.cend())
continue;
if (node_from->assigned_device_name() != node_to->assigned_device_name())
continue;
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
// Don't exceed the maximum cluster size.
if (clusters[from].Size() + clusters[to].Size() >
flags->tf_xla_max_cluster_size) {
continue;
}
// If contracting the edge would create a cycle, bail out.
// However, just because we can't merge the clusters now does not mean
// we won't be able to merge them in the future.
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) continue;
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
// Count the number of elements in each cluster.
std::vector<int> cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
cluster_sizes[cluster]++;
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
// * have at least flags->tf_xla_min_cluster_size elements (applicable only
// if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
bool marked_for_compilation = false;
if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
} else if (options.flib_def
->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
.ok()) {
marked_for_compilation = compile_attr;
}
// Compile if this operator is placed on a device that requires
// compilation.
bool requires_jit = false;
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
XlaOpRegistry::GetJitDevice(device_type.type(),
/*jit_device_name=*/nullptr, &requires_jit);
// Or compile if this is a cluster of >= min_cluster_size compilable
// operators.
if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
requires_jit) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = strings::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
}
}
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
options.flib_def);
}
return Status::OK();
}
} // namespace tensorflow
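To make the contraction-ordering remark in RunImpl concrete, here is a small sketch (not part of the change) using the GraphCycles API directly: with edges 1->2, 2->3 and 1->3, contracting 1->3 first is refused because of the indirect path through 2, but contracting 1->2 first makes the later 1->3 contraction legal.

#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/platform/logging.h"

void ContractionOrderExample() {
  tensorflow::GraphCycles g;
  int n1 = g.NewNode(), n2 = g.NewNode(), n3 = g.NewNode();
  CHECK(g.InsertEdge(n1, n2));
  CHECK(g.InsertEdge(n2, n3));
  CHECK(g.InsertEdge(n1, n3));
  CHECK(!g.ContractEdge(n1, n3));  // indirect path n1 -> n2 -> n3 blocks this
  CHECK(g.ContractEdge(n1, n2));   // merge n1 and n2 first...
  CHECK(g.ContractEdge(n1, n3));   // ...then the n1 -> n3 edge contracts fine
}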

View File

@ -0,0 +1,55 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An optimization pass that marks nodes that are to be compiled with
// attribute kXlaClusterAttr. Nodes with the same cluster ID will be compiled
// together.
#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// The attribute that marks nodes to be grouped into functions by the
// encapsulate subgraphs pass.
extern const char* const kXlaClusterAttr;
// Pass that marks a subset of operators in the graph with attribute
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
class MarkForCompilationPass : public GraphOptimizationPass {
public:
MarkForCompilationPass() = default;
Status Run(const GraphOptimizationPassOptions& options) override;
// Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
// unconditionally, call RunImpl() directly.
// is_compilable_fn, if set, is a predicate that must be true for a node to
// be compiled.
Status RunImpl(const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn = {});
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_

View File

@ -0,0 +1,357 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
void MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
n->set_assigned_device_name(kCpuDevice);
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
CHECK(pass.RunImpl(opt_options).ok());
}
void MarkForCompilation(std::unique_ptr<Graph>* graph) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
MarkForCompilation(graph, &flib_def);
}
std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
string cluster;
if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node->name()] = cluster;
}
}
return ids;
}
TEST(XlaCompilationTest, Chains) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
Node* d =
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_EQ(clusters["E"], clusters["F"]);
EXPECT_NE(clusters["B"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST(XlaCompilationTest, UncompilableCycles) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
TEST(XlaCompilationTest, CompilableCycles) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST(XlaCompilationTest, UnsupportedTypes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_COMPLEX64)
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
TEST(XlaCompilationTest, ConcatWithConstArg) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
Tensor t(DT_INT32, TensorShape());
t.scalar<int32>()() = 0;
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* dim = ops::SourceOp("Const", builder.opts()
.WithName("Dim")
.WithAttr("dtype", DT_INT32)
.WithAttr("value", t));
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", t));
NodeBuilder concat_builder("Concat", "Concat",
builder.opts().op_registry());
concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
builder.opts().FinalizeBuilder(&concat_builder);
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
}
TEST(XlaCompilationTest, FunctionCalls) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
{{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
*flib.add_function() =
FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph, &flib_def);
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
EXPECT_FALSE(clusters["B"].empty());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
// While all of the following ops are notionally compilable, none is permitted
// to start a cluster. So nothing should be compiled.
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
ops::UnaryOp("Shape", d, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
static Status GradForUnaryCwise(FunctionDef* g,
std::vector<FunctionDefHelper::Node> nodes) {
for (auto& n : nodes) {
if (n.attr.empty()) {
n.attr = {{"T", DT_FLOAT}};
}
}
*g = FunctionDefHelper::Define(
// Arg defs
{"x: float", "dy: float"},
// Ret val defs
{"dx: float"},
// Attr defs
{},
// Nodes
nodes);
return Status::OK();
}
// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
{{"y"}, "Tanh", {"x"}},
{{"y2"}, "Square", {"y"}, {}, {"dy"}},
FunctionDefHelper::Const("one", 1.0f),
{{"a"}, "Sub", {"one", "y2"}},
{{"dx"}, "Mul", {"dy", "a"}},
});
// clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);
// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
{{"y"}, "Tanh", {"x"}},
{{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
FunctionDefHelper::Const("one", 1.0f),
{{"a"}, "Sub", {"one", "y2"}},
{{"dx"}, "Mul", {"dy", "a"}},
});
// clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
TEST(XlaCompilationTest, SymbolicGradients) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    // Builds a SymbolicGradient node for the Supported op.
NodeBuilder b_builder("B", "SymbolicGradient",
builder.opts().op_registry());
NameAttrList b_name_attr;
b_name_attr.set_name("Supported");
b_builder.Attr("f", b_name_attr);
b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
b_builder.Attr("Tout", {DT_FLOAT});
b_builder.Input({a, a});
Node* b = builder.opts().FinalizeBuilder(&b_builder);
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    // Builds a SymbolicGradient node for the Unsupported op.
NodeBuilder d_builder("D", "SymbolicGradient",
builder.opts().op_registry());
NameAttrList d_name_attr;
d_name_attr.set_name("Unsupported");
d_builder.Attr("f", d_name_attr);
d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
d_builder.Attr("Tout", {DT_FLOAT});
d_builder.Input({c, c});
builder.opts().FinalizeBuilder(&d_builder);
builder.ToGraph(graph.get());
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
EXPECT_FALSE(clusters["B"].empty());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST(XlaCompilationTest, Loops) {
// Regression test for b/32350199, where the autoclustering code introduced a
// deadlock in a graph containing a while loop.
Scope root = Scope::NewRootScope().ExitOnError();
auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
auto c = ops::Add(root.WithOpName("C"), a, b);
auto enter = ops::Enter(root, c, "aframe");
auto next_iter = ops::NextIteration(root, enter);
auto exit = ops::Exit(root, next_iter);
auto d = ops::Add(root.WithOpName("D"), c, exit);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
root.ToGraph(graph.get());
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
// Nothing should be compiled. In particular, 'd' and 'c' must not be
// compiled.
EXPECT_EQ(0, clusters.size());
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,154 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace {
REGISTER_OP("ParallelCheck")
.Attr("T: list(type) >= 0")
.Input("expected: T")
.Input("actual: T")
.Output("result: T")
.Doc(R"doc(
Op that compares two sets of inputs for near-identity, and propagates the first.
Any inequality is logged to the ERROR log.
)doc");
// Takes 2*N input tensors and outputs the first N of them.
// Logs an error if input tensors i and i + N are not (nearly) identical
// at any position.
class ParallelCheckOp : public OpKernel {
public:
explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
template <typename T>
int CompareTensors(DataType dtype, const char* v0, const char* v1,
int64 num_elts, int input_idx) {
int failed = 0;
const T* p0 = reinterpret_cast<const T*>(v0);
const T* p1 = reinterpret_cast<const T*>(v1);
double rtol;
legacy_flags::ParallelCheckOpFlags* flags =
legacy_flags::GetParallelCheckOpFlags();
if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
&rtol)) {
LOG(ERROR) << "can't convert parallel_check_rtol "
<< flags->parallel_check_rtol << " to double";
}
double atol;
if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
&atol)) {
LOG(ERROR) << "can't convert parallel_check_atol "
<< flags->parallel_check_atol << " to double";
}
for (int i = 0; i < num_elts; ++i) {
bool ok = (p0[i] == p1[i]);
VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
if (!ok) {
if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
          double tolerance =
std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
T diff = p0[i] - p1[i];
if (diff < 0) diff = 0 - diff;
ok = (diff <= tolerance);
}
if (ok) continue;
LOG(ERROR) << "Op " << def().name() << " fails equality at output "
<< input_idx << " type " << DataTypeString(dtype)
<< " element " << i << ": std_val=" << p0[i]
<< " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
if (++failed > 10) break;
}
}
return failed;
}
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "Compute " << def().name();
const int num_pairs = ctx->num_inputs() / 2;
for (int i = 0; i < num_pairs; ++i) {
CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
Tensor t0 = ctx->input(i);
Tensor t1 = ctx->input(i + num_pairs);
int64 num_elts = t0.NumElements();
CHECK_EQ(num_elts, t1.NumElements());
// Compare inputs elementwise for near-exact equality.
const char* v0 = t0.tensor_data().data();
const char* v1 = t1.tensor_data().data();
int failed = 0;
switch (ctx->input_dtype(i)) {
case DT_INT32:
failed =
CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_INT64:
failed =
CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_FLOAT:
failed =
CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_DOUBLE:
failed =
CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_BOOL:
failed =
CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
default:
LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
}
if (failed > 0) {
LOG(ERROR) << "check failed for " << def().name() << " output " << i
<< " num_elts: " << num_elts;
legacy_flags::ParallelCheckOpFlags* flags =
legacy_flags::GetParallelCheckOpFlags();
if (flags->parallel_check_failfast) {
LOG(QFATAL) << "failfast on first parallel-check failure";
}
} else {
VLOG(1) << "check passed for " << def().name() << " output " << i
<< " num_elts: " << num_elts;
}
// Propagate the std value.
if (IsRefType(ctx->input_dtype(i))) {
ctx->forward_ref_input_to_ref_output(i, i);
} else {
ctx->set_output(i, ctx->input(i));
}
}
}
TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
};
REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
ParallelCheckOp);
} // namespace
} // namespace tensorflow
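
For readers following the tolerance check above, here is a minimal standalone sketch of the same test (illustrative names only, not part of the TensorFlow API): two values are considered close when their absolute difference stays within max(atol, |rtol * a|, |rtol * b|).

#include <algorithm>
#include <cmath>

// Illustrative re-statement of the per-element check in ParallelCheckOp.
bool NearlyEqual(double a, double b, double rtol, double atol) {
  const double tolerance =
      std::max(atol, std::max(std::fabs(rtol * a), std::fabs(rtol * b)));
  return std::fabs(a - b) <= tolerance;
}

For example, NearlyEqual(1.0, 1.0 + 1e-7, /*rtol=*/1e-6, /*atol=*/0.0) passes, while exact matches never reach the tolerance path, just as in the kernel.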

View File

@ -0,0 +1,199 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include <numeric>
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options)
: compiler_(options) {}
XlaCompilationCache::~XlaCompilationCache() = default;
string XlaCompilationCache::DebugString() {
return "XLA JIT compilation cache";
}
// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
string result = sig.name;
for (const auto& a : sig.arg_types) {
strings::StrAppend(&result, ",", DataTypeString(a.first),
a.second.DebugString());
}
for (const auto& v : sig.arg_values) {
strings::StrAppend(&result, "; ", v.first, ":", v.second.DebugString());
}
return result;
}
bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
if (name != other.name) return false;
if (arg_types != other.arg_types) return false;
if (arg_values.size() != other.arg_values.size()) return false;
for (int i = 0; i < arg_values.size(); ++i) {
if (arg_values[i].first != other.arg_values[i].first ||
arg_values[i].second.tensor_data() !=
other.arg_values[i].second.tensor_data()) {
return false;
}
}
return true;
}
uint64 XlaCompilationCache::Signature::Hash::operator()(
const XlaCompilationCache::Signature& signature) const {
uint64 h = std::hash<string>()(signature.name);
for (const auto& arg : signature.arg_types) {
h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
h = Hash64Combine(h, std::hash<int>()(arg.second.dims()));
for (int dim : arg.second.dim_sizes()) {
h = Hash64Combine(h, std::hash<int>()(dim));
}
}
for (const auto& arg : signature.arg_values) {
h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
h = Hash64Combine(h, Hash64(arg.second.tensor_data().data(),
arg.second.tensor_data().size()));
}
return h;
}
namespace {
// Builds an XlaCompiler::Argument vector from the arguments to the _XlaLaunch
// op. The first `num_constant_args` arguments must be host-memory Tensors.
std::vector<XlaCompiler::Argument> BuildArguments(int num_constant_args,
OpKernelContext* ctx) {
std::vector<XlaCompiler::Argument> args(ctx->num_inputs());
int parameter_num = 0;
for (int i = 0; i < ctx->num_inputs(); ++i) {
args[i].type = ctx->input(i).dtype();
args[i].shape = ctx->input(i).shape();
if (i < num_constant_args || ctx->input(i).NumElements() == 0) {
args[i].parameter = -1;
args[i].constant_value = ctx->input(i);
} else {
args[i].parameter = parameter_num;
++parameter_num;
}
}
return args;
}
} // namespace
Status XlaCompilationCache::Compile(
const NameAttrList& function, int num_constant_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
std::vector<string> argshapes;
VLOG(2) << "num_inputs = " << ctx->num_inputs()
<< " num_constant_args= " << num_constant_args;
for (int i = 0; i < ctx->num_inputs(); i++) {
TensorShape shape = ctx->input(i).shape();
VLOG(2) << i << ": dtype=" << ctx->input_dtype(i)
<< " present=" << ctx->has_input(i)
<< " shape=" << shape.DebugString();
argshapes.push_back(shape.DebugString());
}
VLOG(2) << "num_outputs = " << ctx->num_outputs();
for (int i = 0; i < ctx->num_outputs(); i++) {
VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i);
}
}
Signature signature;
signature.name = Canonicalize(function.name(), function.attr());
for (int i = 0; i < ctx->num_inputs(); ++i) {
signature.arg_types.emplace_back(ctx->input_dtype(i),
ctx->input(i).shape());
if (i < num_constant_args) {
signature.arg_values.emplace_back(i, ctx->input(i));
}
}
VLOG(2) << "Signature: " << SignatureDebugString(signature);
// The outer lock protects the existence of the cache entry. It does not
// protect the contents of the cache entry.
Entry* entry;
{
mutex_lock lock(mu_);
// Find or create a cache entry.
std::unique_ptr<Entry>& e = cache_[signature];
if (!e) {
e.reset(new Entry);
}
entry = e.get();
}
// Acquire the cache entry lock and compile, if necessary.
// TODO(phawkins): this locking will need to be restructured when we implement
// cache eviction.
mutex_lock entry_lock(entry->mu);
if (!entry->compiled) {
    // Do the actual JIT compilation without holding the outer cache lock mu_
    // (it can take a long time).
std::vector<XlaCompiler::Argument> args =
BuildArguments(num_constant_args, ctx);
std::unique_ptr<FunctionLibraryRuntime> flr(NewFunctionLibraryRuntime(
compiler_.device_mgr(), ctx->env(), compiler_.device(),
TF_GRAPH_DEF_VERSION,
ctx->function_library()->GetFunctionLibraryDefinition(),
OptimizerOptions(), nullptr /* custom_kernel_creator */));
entry->compiled = true;
entry->compilation_status = compiler_.CompileFunction(
flr.get(), function, args, &entry->compilation_result);
}
*compilation_result = &entry->compilation_result;
if (entry->compilation_status.ok() && executable) {
if (entry->executable == nullptr &&
!entry->compilation_result.computation.IsNull()) {
entry->compilation_status = compiler_.BuildExecutable(
entry->compilation_result, &entry->executable);
}
*executable = entry->executable.get();
}
Status status = entry->compilation_status;
return status;
}
} // namespace tensorflow
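
The Signature::Hash implementation above folds the function name, argument dtypes, shapes, and constant values into a single 64-bit key. A standalone sketch of the same folding idea, using plain std:: types and a hypothetical CombineHash mixer in place of Hash64Combine:

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for Hash64Combine: mix a value into a running hash.
std::uint64_t CombineHash(std::uint64_t seed, std::uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

// A signature reduced to an op name plus one dimension vector per argument.
std::uint64_t HashSignature(
    const std::string& name,
    const std::vector<std::vector<std::int64_t>>& arg_shapes) {
  std::uint64_t h = std::hash<std::string>()(name);
  for (const auto& shape : arg_shapes) {
    h = CombineHash(h, shape.size());  // Fold in the rank...
    for (std::int64_t dim : shape) {   // ...then each dimension in order.
      h = CombineHash(h, static_cast<std::uint64_t>(dim));
    }
  }
  return h;
}

Because the fold visits components in a fixed order, equal signatures hash equally, and any new input shape produces a new key and therefore a fresh compilation, which is what the static-shape requirement demands.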

View File

@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
// The XlaCompilationCache class caches the results of the XlaCompiler class,
// which converts a TensorFlow graph into a compiled XLA computation.
//
// Since XLA computations must have static shapes, the cache generates a new
// XLA computation for each new set of input shapes.
//
// Currently no cache eviction policy is implemented and the cache grows without
// bound.
class XlaCompilationCache : public ResourceBase {
public:
explicit XlaCompilationCache(const XlaCompiler::Options& options);
~XlaCompilationCache() override;
  // Compiles a function into an XlaCompiler::CompilationResult that can be used
  // to execute an XLA computation. `compilation_result` must be non-null.
  // If `executable` is non-null, also builds an xla::LocalExecutable and sets
  // `executable` to point to it. The resulting executable pointer may be null if
// the computation has no non-constant outputs.
// Compilation results are cached.
Status Compile(const NameAttrList& function, int num_constant_args,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable);
xla::Client* client() const { return compiler_.client(); }
string DebugString() override;
private:
XlaCompiler compiler_;
std::unique_ptr<FunctionLibraryRuntime> function_library_runtime_;
// Describes the types, shapes and any compile-time constant arguments
// to a kernel.
struct Signature {
string name;
std::vector<std::pair<DataType, TensorShape>> arg_types;
// List of (argument #, value) pairs for arguments whose values are
// part of the JIT signature, and that are therefore constants in any given
// JIT compilation. Tensors must be in host memory.
std::vector<std::pair<int, Tensor>> arg_values;
bool operator==(const Signature& other) const;
struct Hash {
uint64 operator()(const Signature& signature) const;
};
};
static string SignatureDebugString(const Signature& sig);
// The value associated with a cache entry.
struct Entry {
mutex mu;
// Have we tried compiling this entry?
bool compiled = false;
// Did compilation succeed?
Status compilation_status GUARDED_BY(mu);
// Output of the XlaCompiler.
XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
// The XLA executable compiled from <computation>. May be null if no
// executable has been built.
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
};
mutex mu_;
std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
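
The locking discipline documented in Compile (the outer mutex guards only the existence of a cache entry, while each entry carries its own mutex for the expensive fill) is a general memoization pattern. A minimal standalone sketch, with hypothetical names and a std::string standing in for the compilation result:

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct MemoEntry {
  std::mutex mu;
  bool filled = false;
  std::string result;  // Stand-in for the compiled artifact.
};

class MemoCache {
 public:
  // Computes the value for `key` at most once; concurrent callers for the same
  // key block only on that entry, not on the whole cache.
  const std::string& GetOrCompute(const std::string& key) {
    MemoEntry* entry;
    {
      std::lock_guard<std::mutex> lock(mu_);  // Guards entry existence only.
      std::unique_ptr<MemoEntry>& e = cache_[key];
      if (!e) e.reset(new MemoEntry);
      entry = e.get();
    }
    std::lock_guard<std::mutex> entry_lock(entry->mu);  // Guards entry contents.
    if (!entry->filled) {
      entry->result = "compiled(" + key + ")";  // Expensive work would go here.
      entry->filled = true;
    }
    // Entries are never evicted, so the returned reference stays valid.
    return entry->result;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::unique_ptr<MemoEntry>> cache_;
};

Two threads asking for the same new key contend only on that entry's mutex; lookups for other keys proceed after the brief map lock, which mirrors why the real cache takes mu_ only around the map access.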

View File

@ -0,0 +1,60 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
const char* const DEVICE_XLA_CPU = "XLA_CPU";
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
};
Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
static XlaDeviceOpRegistrations* registrations =
RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
(void)registrations;
std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
DEVICE_CPU_XLA_JIT, options, name_prefix,
&device));
devices->push_back(device.release());
return Status::OK();
}
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaDeviceLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
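
REGISTER_LOCAL_DEVICE_FACTORY above relies on a static registrar object whose constructor runs at program start. The following is a generic, self-contained illustration of that idiom (a toy registry, not TensorFlow's actual DeviceFactory machinery):

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy registry mapping a device name to a factory callback.
std::map<std::string, std::function<void()>>& Registry() {
  static auto* registry = new std::map<std::string, std::function<void()>>;
  return *registry;
}

// Registration happens as a side effect of constructing a static object.
struct Registrar {
  Registrar(const std::string& name, std::function<void()> factory) {
    Registry()[name] = std::move(factory);
  }
};

static Registrar xla_cpu_toy_registrar("XLA_CPU_TOY", [] {
  std::cout << "creating XLA_CPU_TOY device\n";
});

int main() {
  for (const auto& kv : Registry()) kv.second();  // "Creates" every registered device.
}

Keeping the registry behind a function-local static avoids initialization-order problems between the registrar objects and the map they write into.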

View File

@ -0,0 +1,219 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_device.h"
#include <stdlib.h>
#include <unordered_set>
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace tensorflow {
/* static */ Status XlaDevice::Create(
const string& platform_name, const string& device_name, int device_ordinal,
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
// These are no-ops if they have already been done previously for
// this device_name/jit_device_name pair.
XlaOpRegistry::RegisterJitKernels();
XlaOpRegistry::RegisterJitDevice(device_name, jit_device_name,
/*requires_jit=*/true);
auto platform = perftools::gputools::MultiPlatformManager::PlatformWithName(
platform_name);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
}
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
strings::StrCat(name_prefix, "/device:", device_name, ":",
device_ordinal),
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
static Allocator* allocator = new XlaDeviceAllocator;
device->reset(new XlaDevice(options, attrs, device_ordinal,
DeviceType(jit_device_name),
platform.ValueOrDie(), allocator));
return Status::OK();
}
XlaDevice::Metadata::Metadata(int device_ordinal,
perftools::gputools::Platform* platform,
const DeviceType& device_type)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
perftools::gputools::Platform* XlaDevice::Metadata::platform() const {
return platform_;
}
XlaDevice::Metadata::~Metadata() {}
xla::LocalClient* XlaDevice::Metadata::client() const {
auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
return client.ValueOrDie();
}
const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return device_type_;
}
string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; }
XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceAttributes& attrs, int device_ordinal,
const DeviceType& jit_device_name,
perftools::gputools::Platform* platform,
Allocator* xla_allocator)
: LocalDevice(options, attrs, xla_allocator),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator),
platform_(platform) {
// Store the platform in the resource manager so Ops can retrieve it
  // e.g., to lazily create an XlaCompilationCache object.
TF_CHECK_OK(resource_manager()->Create<Metadata>(
resource_manager()->default_container(), "xla_metadata",
new Metadata(device_ordinal_, platform_, jit_device_name_)));
}
XlaDevice::~XlaDevice() {}
xla::LocalClient* XlaDevice::client() const {
// We lazily create the client because the platform commits to the
// details of the host hardware when the client is created, so we
// don't want to do it until we get a chance to hook the platform up
// to a simulator.
// For now GetOrCreateLocalClient always returns success when passed
// a non-null platform. If that changes we may have to plumb in some
// way to pass Status back.
return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
}
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
if (attr.on_host()) {
return cpu_allocator();
} else {
return xla_allocator_;
}
}
Status XlaDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) {
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
XlaDeviceContext* ctx = new XlaDeviceContext(client());
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
(*device_context_map)[n->id()] = ctx;
}
return Status::OK();
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);
}
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
VLOG(1) << "XlaDevice::MakeTensorFromProto";
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
Status status;
if (alloc_attrs.on_host()) {
*tensor = parsed;
} else {
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
XlaTransferManager manager(client());
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
*tensor = copy;
}
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
return status;
}
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* {
return new XlaDeviceDummyOp(context);
};
for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(jit_device)) {
KernelDef* def = new KernelDef(*jit_def);
def->set_device_type(device);
registrations->op_kernel_registrars.emplace_back(
new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp",
dummy_factory));
}
return registrations;
}
} // namespace tensorflow
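
MakeTensorFromProto above turns an asynchronous, callback-based copy into a blocking one by capturing the callback's Status and waiting on a Notification. A standalone sketch of the same shape using std::promise/std::future (the Status alias and AsyncCopy are illustrative placeholders, not TensorFlow types):

#include <functional>
#include <future>
#include <string>

using Status = std::string;  // Empty string == OK; anything else is an error.

// Placeholder for an asynchronous copy that reports completion via callback.
void AsyncCopy(std::function<void(const Status&)> done) {
  done("");  // Pretend the copy finished successfully.
}

// Block until the callback fires and surface its status to the caller.
Status CopyAndWait() {
  std::promise<Status> promise;
  std::future<Status> result = promise.get_future();
  AsyncCopy([&promise](const Status& s) { promise.set_value(s); });
  return result.get();
}

int main() { return CopyAndWait().empty() ? 0 : 1; }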

View File

@ -0,0 +1,120 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
// is managed by XLA.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace tensorflow {
class XlaDevice : public LocalDevice {
public:
// Wrapper class to store metadata about the XlaDevice in the
  // resource manager, where it can be looked up, e.g., when lazily
  // creating the XlaCompilationCache.
class Metadata : public ResourceBase {
public:
Metadata(int device_ordinal, perftools::gputools::Platform* platform,
const DeviceType& device_type);
~Metadata() override;
// The index of the device on this host.
int device_ordinal() const;
perftools::gputools::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
string DebugString() override;
private:
const int device_ordinal_;
const DeviceType device_type_;
perftools::gputools::Platform* platform_; // Not owned.
};
// Factory function. 'platform_name' is the name of the XLA platform.
// 'device_name' is the name of the Tensorflow device to create.
// 'jit_device_name' is the name of the corresponding JIT device.
static Status Create(const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
::perftools::gputools::Platform* platform,
Allocator* xla_allocator);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override { return Status::OK(); }
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override;
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override;
xla::LocalClient* client() const;
private:
// Which hardware device in the client's platform this XlaDevice controls.
const int device_ordinal_;
// The name of the device that is used to compile Ops for this XlaDevice.
const DeviceType& jit_device_name_;
Allocator* xla_allocator_; // Not owned.
::perftools::gputools::Platform* platform_; // Not owned.
};
// Builds dummy OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

View File

@ -0,0 +1,181 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
// The contents of tensors allocated by XlaDeviceAllocator.
struct XlaGlobalData {
mutable mutex mu;
// May be nullptr if there is no xla::GlobalData backing this Tensor.
std::shared_ptr<xla::GlobalData> data GUARDED_BY(mu);
};
// The allocator used for Tensors assigned to the XLA device. The allocator
// doesn't actually back Tensors with storage. Instead, each tensor contains
// an XlaGlobalData that wraps XLA-managed storage.
XlaDeviceAllocator::XlaDeviceAllocator() = default;
XlaDeviceAllocator::~XlaDeviceAllocator() = default;
string XlaDeviceAllocator::Name() { return "xla"; }
void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // Regardless of the size requested, always allocate an XlaGlobalData. Respect
  // the alignment request because there is alignment checking even for Tensors
// whose data is never accessed.
void* p = port::aligned_malloc(sizeof(XlaGlobalData), alignment);
VLOG(2) << "Allocated XLA device tensor " << p;
return new (p) XlaGlobalData();
}
void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
XlaGlobalData* global_data = reinterpret_cast<XlaGlobalData*>(ptr);
VLOG(2) << "Deallocated XLA device tensor " << ptr;
global_data->~XlaGlobalData();
port::aligned_free(ptr);
}
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
// Don't run any constructors or destructors for complex objects,
// since there is no backing store for the tensor to run them
// on. strings are the only complex objects currently stored in
// Tensors. If others are added, this set of overrides must be
// extended to include them.
void XlaDeviceAllocator::RunStringCtor(string* p, size_t n) {}
void XlaDeviceAllocator::RunStringDtor(string* p, size_t n) {}
void XlaDeviceAllocator::RunResourceCtor(ResourceHandle* p, size_t n) {}
void XlaDeviceAllocator::RunResourceDtor(ResourceHandle* p, size_t n) {}
static const XlaGlobalData* CastTensorToXlaGlobalData(const Tensor& tensor) {
const XlaGlobalData* expression =
reinterpret_cast<const XlaGlobalData*>(tensor.tensor_data().data());
return expression;
}
static XlaGlobalData* CastTensorToXlaGlobalData(Tensor* tensor) {
const XlaGlobalData* expression =
reinterpret_cast<const XlaGlobalData*>(tensor->tensor_data().data());
return const_cast<XlaGlobalData*>(expression);
}
XlaTransferManager::XlaTransferManager(xla::Client* client) : client_(client) {}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
if (cpu_tensor->NumElements() > 0) {
VLOG(2) << "CopyCPUTensorToDevice "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " " << reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< cpu_tensor->NumElements();
xla::Literal literal;
Status status = HostTensorToLiteral(*cpu_tensor, &literal);
if (!status.ok()) {
done(status);
return;
}
auto gd = client_->TransferToServer(literal);
if (!gd.ok()) {
done(gd.status());
return;
}
SetTensorGlobalData(
std::shared_ptr<xla::GlobalData>(std::move(gd.ValueOrDie())),
device_tensor);
} else {
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
}
done(Status::OK());
}
void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
if (device_tensor->NumElements() > 0) {
VLOG(2) << "CopyDeviceTensorToCPU"
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< device_tensor->NumElements();
std::shared_ptr<xla::GlobalData> global_data =
GetTensorGlobalData(*device_tensor);
xla::Shape shape;
Status status =
TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape);
if (!status.ok()) {
done(status);
return;
}
auto result = client_->Transfer(*global_data, &shape);
if (!result.ok()) {
done(result.status());
return;
}
const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie());
void* dst_ptr = DMAHelper::base(cpu_tensor);
size_t total_bytes = cpu_tensor->TotalBytes();
memcpy(dst_ptr, src_ptr, total_bytes);
} else {
VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
}
done(Status::OK());
}
std::shared_ptr<xla::GlobalData> XlaTransferManager::GetTensorGlobalData(
const Tensor& tensor) {
const XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
mutex_lock lock(data->mu);
CHECK(data->data);
return data->data;
}
void XlaTransferManager::SetTensorGlobalData(
std::shared_ptr<xla::GlobalData> global_data, Tensor* tensor) {
XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
mutex_lock lock(data->mu);
data->data = std::move(global_data);
}
XlaDeviceContext::XlaDeviceContext(xla::Client* client) : manager_(client) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done);
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
done);
}
} // namespace tensorflow
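
The allocator above never allocates real tensor storage; it constructs one XlaGlobalData handle per "buffer" with placement new and destroys it explicitly on free. A standalone sketch of that trick, where Handle is a made-up stand-in for XlaGlobalData:

#include <cstddef>
#include <cstdlib>
#include <memory>
#include <new>

struct Handle {
  std::shared_ptr<int> data;  // Placeholder for the remotely managed storage.
};

void* AllocateHandle(std::size_t alignment) {
  // std::aligned_alloc needs the size rounded up to a multiple of alignment
  // (alignment is assumed to be a power of two supported by the platform).
  std::size_t size = ((sizeof(Handle) + alignment - 1) / alignment) * alignment;
  void* p = std::aligned_alloc(alignment, size);
  if (p == nullptr) return nullptr;
  return new (p) Handle();  // Construct the handle in place.
}

void DeallocateHandle(void* ptr) {
  static_cast<Handle*>(ptr)->~Handle();  // Run the destructor explicitly...
  std::free(ptr);                        // ...then release the raw storage.
}

int main() {
  void* p = AllocateHandle(alignof(std::max_align_t));
  DeallocateHandle(p);
}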

View File

@ -0,0 +1,92 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
#include <memory>
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// The allocator used for Tensors assigned to the XLA device. The allocator
// doesn't actually back Tensors with storage. Instead, each tensor is a thin
// wrapper around XLA-managed storage.
class XlaDeviceAllocator : public Allocator {
public:
XlaDeviceAllocator();
~XlaDeviceAllocator() override;
string Name() override;
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
void GetStats(AllocatorStats* stats) override;
private:
void RunStringCtor(string* p, size_t n) override;
void RunStringDtor(string* p, size_t n) override;
void RunResourceCtor(ResourceHandle* p, size_t n) override;
void RunResourceDtor(ResourceHandle* p, size_t n) override;
};
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
explicit XlaTransferManager(xla::Client* client);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
// Helper methods to get/set the xla::GlobalData backing a Tensor on the
// XlaDevice.
static std::shared_ptr<xla::GlobalData> GetTensorGlobalData(
const Tensor& tensor);
static void SetTensorGlobalData(std::shared_ptr<xla::GlobalData> global_data,
Tensor* tensor);
private:
xla::Client* client_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
// implementation must inherit from DeviceContext but otherwise just
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(xla::Client* client);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
StatusCallback done) const override;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
private:
XlaTransferManager manager_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_

View File

@ -0,0 +1,171 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_device_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
namespace {
Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) {
XlaDevice::Metadata* metadata;
Status s = rm->Lookup<XlaDevice::Metadata>(rm->default_container(),
"xla_metadata", &metadata);
if (!s.ok()) {
return s;
}
core::ScopedUnref metadata_ref(metadata);
XlaCompiler::Options options;
options.device_type = metadata->jit_device_type();
options.client = metadata->client();
options.allow_cpu_custom_calls = false;
options.local_executable_has_hybrid_result = false;
*compiler = new XlaCompilationCache(options);
return Status::OK();
}
} // namespace
XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
const NameAttrList* func;
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
function_ = *func;
VLOG(1) << "XlaDeviceLaunch created function="
<< Canonicalize(function_.name(), function_.attr());
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
}
void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaDeviceLaunch::Compute "
<< Canonicalize(function_.name(), function_.attr());
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
XlaCompilationCache* compiler;
OP_REQUIRES_OK(ctx,
rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_compiler", &compiler,
[rm](XlaCompilationCache** compiler) {
return BuildCompilationCache(rm, compiler);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref compiler_ref(compiler);
const XlaCompiler::CompilationResult* kernel;
OP_REQUIRES_OK(
ctx,
compiler->Compile(function_, num_constant_args_, ctx, &kernel, nullptr));
VLOG(1) << "Executing XLA Computation...";
OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(),
errors::Internal("Unexpected number of outputs"));
// Run the computation, if any. There might not be a computation if all
// outputs were compile-time constants.
std::vector<std::unique_ptr<xla::GlobalData>> outputs;
if (!kernel->computation.IsNull()) {
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
// Convert argument tensors to xla::GlobalData pointers.
std::vector<std::shared_ptr<xla::GlobalData>> arg_handles(
kernel->xla_input_shapes.size());
std::vector<xla::GlobalData*> arg_ptrs(kernel->xla_input_shapes.size());
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int input_num = kernel->xla_input_shapes[i].first;
arg_handles[i] =
XlaTransferManager::GetTensorGlobalData(ctx->input(input_num));
arg_ptrs[i] = arg_handles[i].get();
}
// Execute the computation.
xla::ExecutionProfile profile;
Env* env = Env::Default();
auto start_time = env->NowMicros();
auto result = compiler->client()->Execute(
kernel->computation, arg_ptrs, &kernel->xla_output_shape, &profile);
auto elapsed = env->NowMicros() - start_time;
OP_REQUIRES(ctx, result.ok(), result.status());
VLOG(1) << "Elapsed time: " << elapsed << "us";
VLOG(1) << "ExecutionProfile: " << profile.DebugString();
if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) {
auto outputs_or_error =
compiler->client()->DeconstructTuple(*result.ValueOrDie());
OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status());
outputs = outputs_or_error.ConsumeValueOrDie();
} else {
outputs.push_back(result.ConsumeValueOrDie());
}
}
XlaDeviceContext* device_context = ctx->op_device_context<XlaDeviceContext>();
// Copy XLA outputs to the operator's outputs.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
Tensor* output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(i, kernel->outputs[i].shape, &output));
if (kernel->outputs[i].is_constant) {
// TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and
// remove the copy from this code.
Status status;
device_context->CopyCPUTensorToDevice(
&kernel->outputs[i].constant_value, nullptr, output,
[&status](const Status& s) { status = s; });
if (!status.ok()) {
ctx->SetStatus(status);
return;
}
} else {
CHECK_LT(output_num, outputs.size());
XlaTransferManager::SetTensorGlobalData(
std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])),
output);
++output_num;
}
}
VLOG(1) << "Done";
}
XlaDeviceLaunchOp::~XlaDeviceLaunchOp() {
VLOG(1) << "XlaDeviceLaunch destroyed";
}
} // namespace tensorflow
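
The LookupOrCreate call above builds the compilation cache lazily, once per resource container, and holds a reference to it while the op runs. A standalone approximation with shared_ptr ownership (these names are illustrative; the real code uses the ResourceMgr's own reference counting):

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class ResourceContainer {
 public:
  // Returns the resource stored under `name`, creating it with `factory` on
  // first use. The returned shared_ptr keeps it alive while the caller runs,
  // much like core::ScopedUnref does for ResourceMgr resources.
  template <typename T>
  std::shared_ptr<T> LookupOrCreate(
      const std::string& name, std::function<std::shared_ptr<T>()> factory) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = resources_.find(name);
    if (it == resources_.end()) {
      it = resources_.emplace(name, factory()).first;
    }
    return std::static_pointer_cast<T>(it->second);
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::shared_ptr<void>> resources_;
};

int main() {
  ResourceContainer rm;
  // A per-container "cache" built on first use, shared by later lookups.
  auto cache = rm.LookupOrCreate<std::string>(
      "xla_compiler", [] { return std::make_shared<std::string>("cache"); });
  return cache->empty() ? 1 : 0;
}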

View File

@ -0,0 +1,50 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can use an
// XlaCompilationCache to compile and execute code that is specific to the
// shapes of the input Tensors.
class XlaDeviceLaunchOp : public OpKernel {
public:
explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx);
~XlaDeviceLaunchOp() override;
void Compute(OpKernelContext* ctx) override;
private:
NameAttrList function_;
int num_constant_args_;
Tensor dummy_tensor_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
namespace tensorflow {
void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
const Tensor& rhs) {
std::shared_ptr<xla::GlobalData> gd =
XlaTransferManager::GetTensorGlobalData(rhs);
XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
}
XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
LOG(FATAL) << "Attempted to execute Op " << name() << "type " << type_string()
<< " on an XLA device. This should never happen.";
}
} // namespace tensorflow

View File

@ -0,0 +1,118 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Common kernel registrations for XLA devices.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
#include "tensorflow/compiler/jit/xla_device_launch_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/assign_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
// Implementation of Assign for XLA devices.
class XlaDeviceAssignOp : public AssignOp {
public:
using AssignOp::AssignOp;
void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
};
// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
// operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel {
public:
explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
};
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER( \
Name("_XlaLaunch").Device(DEVICE).HostMemory("constants"), KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
REGISTER_KERNEL_BUILDER( \
Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \
REGISTER_KERNEL_BUILDER( \
Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), \
XlaDeviceDummyOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
TemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
.Device(DEVICE) \
.TypeConstraint("T", TYPES), \
DestroyTemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
.Device(DEVICE) \
.TypeConstraint("dtype", TYPES) \
.HostMemory("is_initialized"), \
IsVariableInitializedOp); \
REGISTER_KERNEL_BUILDER( \
Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \
XlaDeviceAssignOp); \
\
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
ControlTriggerOp); \
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
NextIterationOp); \
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
SwitchOp); \
REGISTER_KERNEL_BUILDER( \
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
REGISTER_KERNEL_BUILDER(Name("LoopCond") \
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
IdentityOp);
// TODO(phawkins): do we really need Placeholder? Should it be a real
// implementation of Placeholder?
// TODO(b/32507444): the registrations for the control flow operators are
// temporary and exist primarily to work around a bug in the graph partitioning
// code.
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_


@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
const char* const DEVICE_XLA_GPU = "XLA_GPU";
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
};
Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
static XlaDeviceOpRegistrations* registrations =
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations;
std::unique_ptr<XlaDevice> device;
Status status =
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
LOG(WARNING) << "Failed to create XLA_GPU device: " << status;
return Status::OK();
}
devices->push_back(device.release());
return Status::OK();
}
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
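// Usage sketch (hedged): once this factory is linked into a TensorFlow
// binary, ops can be pinned to the device from Python, e.g.
//
//   with tf.device("/device:XLA_GPU:0"):
//     y = tf.matmul(x, w)
//
// The device string follows the usual "/device:<TYPE>:<ordinal>" convention;
// treat the snippet as an illustration rather than part of this file's API.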
// Kernel registrations
constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaDeviceLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow


@ -0,0 +1,342 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace gpu = perftools::gputools;
namespace tensorflow {
REGISTER_OP("_XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
.Attr("Targs: list(type) >= 0")
.Output("results: Tresults")
.Attr("Tresults: list(type) >= 0")
.Attr("function: func")
.Doc("XLA Launch Op. For use by the XLA JIT only.");
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
XlaAllocator(const perftools::gputools::Platform* platform,
OpKernelContext* op_context);
~XlaAllocator() override;
xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure = true) override;
Status Deallocate(int device_ordinal,
perftools::gputools::DeviceMemoryBase* mem) override;
// Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
// interpreted as having data type 'dtype' and shape 'shape'.
Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
const TensorShape& shape, Tensor* tensor) const;
private:
OpKernelContext* const op_context_;
// Map from pointer address to the owning Tensor; used by
// MakeTensorFromBuffer. Also used to automatically release Tensors when the
// allocator is freed.
std::unordered_map<void*, Tensor> tensors_;
};
XlaAllocator::XlaAllocator(const perftools::gputools::Platform* platform,
OpKernelContext* op_context)
: xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
XlaAllocator::~XlaAllocator() = default;
xla::StatusOr<perftools::gputools::DeviceMemoryBase> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocatorAttributes allocator_attrs;
allocator_attrs.set_on_host(false);
AllocationAttributes allocation_attrs;
allocation_attrs.no_retry_on_failure = !retry_on_failure;
Tensor t;
Status status = op_context_->allocate_temp(
DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
allocation_attrs);
if (!status.ok()) {
VLOG(2) << "Allocation failed " << size;
return status;
}
void* data =
reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
TF_RET_CHECK(data != nullptr);
tensors_[data] = t;
return perftools::gputools::DeviceMemoryBase(data, size);
}
Status XlaAllocator::Deallocate(int device_ordinal,
perftools::gputools::DeviceMemoryBase* mem) {
if (mem->opaque() != nullptr) {
if (tensors_.erase(mem->opaque()) == 0) {
return tensorflow::errors::InvalidArgument("Unknown tensor address");
}
}
return Status::OK();
}
Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
DataType dtype,
const TensorShape& shape,
Tensor* out_tensor) const {
void* ptr = const_cast<void*>(buffer.opaque());
auto it = tensors_.find(ptr);
if (it == tensors_.end()) {
return errors::InvalidArgument("Unknown tensor address");
}
const Tensor& tensor = it->second;
int64 output_size = DataTypeSize(dtype) * shape.num_elements();
if (tensor.TotalBytes() == output_size) {
out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
} else {
Tensor slice = tensor.Slice(0, output_size);
out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
}
return Status::OK();
}
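// Usage sketch (mirrors Compute() further below; shown here for orientation,
// not as an additional code path):
//
//   XlaAllocator xla_allocator(client->platform(), ctx);
//   run_options.set_allocator(&xla_allocator);
//   // ... executable->Run(arg_ptrs, run_options) ...
//   Tensor out;
//   TF_CHECK_OK(
//       xla_allocator.MakeTensorFromBuffer(buffer, dtype, shape, &out));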
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: OpKernel(ctx), device_type_(ctx->device_type()) {
const NameAttrList* func;
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
function_ = *func;
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
}
Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) {
gpu::Platform::Id platform_id;
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id = gpu::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
platform_id = gpu::cuda::kCudaPlatformId;
} else {
return errors::InvalidArgument("Unknown device type for local _XlaLaunch");
}
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
}
auto client =
xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
if (!client.ok()) {
return client.status();
}
const string* compiler_device;
if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device,
/*requires_jit=*/nullptr)) {
return errors::InvalidArgument("No JIT device registered for ",
device_type_.type());
}
XlaCompiler::Options options;
options.device_type = DeviceType(*compiler_device);
options.client = client.ValueOrDie();
options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId);
options.local_executable_has_hybrid_result = true;
*compiler = new XlaCompilationCache(options);
return Status::OK();
}
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOp::Compute "
<< Canonicalize(function_.name(), function_.attr());
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
gpu::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
XlaCompilationCache* compiler;
OP_REQUIRES_OK(ctx,
rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_compiler", &compiler,
[this](XlaCompilationCache** compiler) {
return BuildCompilationCache(compiler);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref compiler_ref(compiler);
xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client());
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
OP_REQUIRES_OK(ctx,
compiler->Compile(function_, num_constant_args_, ctx, &kernel,
&executable));
VLOG(1) << "Executing XLA Computation...";
// Builds an XLA allocator for the device.
XlaAllocator xla_allocator(client->platform(), ctx);
XlaLocalRuntimeContext local_runtime_context;
std::unique_ptr<xla::ShapedBuffer> output;
  bool output_is_tuple = false;
if (!kernel->computation.IsNull()) {
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers.resize(kernel->xla_input_shapes.size());
std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
// Pass remaining parameters.
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->xla_input_shapes[i].first;
const xla::Shape& shape = kernel->xla_input_shapes[i].second;
gpu::DeviceMemoryBase dmem(
const_cast<char*>(ctx->input(arg_num).tensor_data().data()),
ctx->input(arg_num).tensor_data().size());
arg_buffers[i] =
xla::ShapedBuffer::MakeArrayShapedBuffer(
shape, client->platform(), client->default_device_ordinal(), dmem)
.ConsumeValueOrDie();
arg_ptrs[i] = arg_buffers[i].get();
}
// Make the final parameter point at local_runtime_context.
if (kernel->requires_runtime_context) {
gpu::DeviceMemoryBase local_runtime_context_dmem(
&local_runtime_context, sizeof(local_runtime_context));
arg_buffers.push_back(
xla::ShapedBuffer::MakeArrayShapedBuffer(
xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
client->default_device_ordinal(), local_runtime_context_dmem)
.ConsumeValueOrDie());
arg_ptrs.push_back(arg_buffers.back().get());
}
// Execute the computation.
VLOG(2) << "Executing computation.";
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(&xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
Env* env = Env::Default();
auto start_time = env->NowMicros();
auto run_result = executable->Run(arg_ptrs, run_options);
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
if (local_runtime_context.error) {
ctx->CtxFailure(errors::InvalidArgument(
"Compiled kernel returned error: ", local_runtime_context.error_msg));
return;
}
output = std::move(run_result.ValueOrDie());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
}
output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (kernel->outputs[i].is_constant) {
// Output is a constant
const Tensor& const_tensor = kernel->outputs[i].constant_value;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
VLOG(1) << "Constant output tensor on device";
Tensor* output_tensor;
TF_CHECK_OK(
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
const void* src_ptr = DMAHelper::base(&const_tensor);
void* dst_ptr = DMAHelper::base(output_tensor);
gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
} else {
// No copy required.
ctx->set_output(i, const_tensor);
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
gpu::DeviceMemoryBase buffer;
if (output_is_tuple) {
buffer = output->buffer({output_num});
} else {
CHECK_EQ(0, output_num);
buffer = output->buffer({});
}
Tensor output_tensor;
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
buffer, ctx->expected_output_dtype(i), shape,
&output_tensor));
ctx->set_output(i, output_tensor);
++output_num;
}
if (VLOG_IS_ON(3)) {
VLOG(3) << ctx->mutable_output(i)->DebugString();
}
}
VLOG(1) << "Done";
}
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
XlaLocalLaunchOp);
REGISTER_KERNEL_BUILDER(
Name("_XlaLaunch").Device(DEVICE_GPU).HostMemory("constants"),
XlaLocalLaunchOp);
} // namespace tensorflow


@ -0,0 +1,55 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use an 'XlaCompilationCache' to compile and execute code specific
// to the shapes of input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes
// arguments into/out of XLA in device memory.
class XlaLocalLaunchOp : public OpKernel {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
void Compute(OpKernelContext* ctx) override;
private:
// Builds a XlaCompilationCache class suitable for the current device.
Status BuildCompilationCache(XlaCompilationCache** compiler);
DeviceType device_type_;
NameAttrList function_;
int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_


@ -0,0 +1,352 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
generate_backend_suites()
py_library(
name = "xla_test",
testonly = 1,
srcs = ["xla_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:variables",
],
)
cc_library(
name = "depthwise_conv2d_test_kernel",
testonly = 1,
srcs = ["depthwise_conv2d_test_kernel.cc"],
deps = ["//tensorflow/core:framework_lite"],
)
tf_xla_py_test(
name = "binary_ops_test",
size = "small",
srcs = ["binary_ops_test.py"],
shard_count = 5,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "clustering_test",
size = "small",
srcs = ["clustering_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "concat_ops_test",
size = "small",
srcs = ["concat_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "conv2d_test",
size = "medium",
srcs = ["conv2d_test.py"],
shard_count = 10,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
srcs = ["dynamic_stitch_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "function_test",
size = "small",
srcs = ["function_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
srcs = ["lrn_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "nary_ops_test",
size = "small",
srcs = ["nary_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "nullary_ops_test",
size = "small",
srcs = ["nullary_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "pooling_ops_test",
size = "medium",
srcs = ["pooling_ops_test.py"],
shard_count = 10,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "reduce_ops_test",
size = "medium",
srcs = ["reduce_ops_test.py"],
shard_count = 5,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "unary_ops_test",
size = "small",
srcs = ["unary_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
additional_deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
)
cuda_py_test(
name = "jit_test",
size = "medium",
srcs = ["jit_test.py"],
additional_deps = [
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
],
)
cc_library(
name = "randomized_tests_library",
testonly = 1,
srcs = ["randomized_tests.cc"],
deps = [
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
],
)
tf_cuda_cc_test(
name = "randomized_tests",
# This test is randomized, so only run it if explicitly requested.
tags = [
"manual",
"noguitar",
"notap",
],
deps = [":randomized_tests_library"],
)
py_library(
name = "lstm",
testonly = 1,
srcs = ["lstm.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
],
)
cuda_py_test(
name = "lstm_test",
srcs = ["lstm_test.py"],
additional_deps = [
":lstm",
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:variables",
],
)
# An example of ahead-of-time compilation using tfcompile. The
# lstm_layer_inference.pbtxt file was generated by running lstm_test
# --dump_graph_dir, and the config file was written by hand.
#
# Run the following to build a minimal benchmark of the computation on Android:
# $ bazel build -c opt --config=android_arm \
# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark
#
# Currently the resulting binary size is ~190KB
tf_library(
name = "lstm_layer_inference",
testonly = 1,
config = "lstm_layer_inference.config.pbtxt",
cpp_class = "LSTMLayerInference",
graph = "lstm_layer_inference.pbtxt",
tags = ["manual"],
tfcompile_flags = "--xla_cpu_multi_thread_eigen=false",
)
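# A hedged sketch of how the generated header is consumed: tfcompile emits a
# C++ class named by `cpp_class` with per-argument setters and a Run() method.
# Treat the exact accessor names below as assumptions for illustration.
#
#   LSTMLayerInference computation;
#   computation.set_arg0_data(input_buffer);
#   computation.Run();
#   const float* out = computation.result0_data();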
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,749 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for binary operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
class BinaryOpsTest(XLATestCase):
"""Test cases for binary operators."""
def _testBinary(self, op, a, b, expected, equality_test=None):
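    """Builds op(a, b) on placeholder inputs inside the test scope, runs it,
    and compares the result to `expected` via `equality_test` (assertAllClose
    with rtol=1e-3 by default)."""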
with self.test_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
output = op(pa, pb)
result = session.run(output, {pa: a, pb: b})
if equality_test is None:
equality_test = self.assertAllClose
equality_test(result, expected, rtol=1e-3)
def ListsAreClose(self, result, expected, rtol):
"""Tests closeness of two lists of floats."""
self.assertEqual(len(result), len(expected))
for i in range(len(result)):
self.assertAllClose(result[i], expected[i], rtol)
def testFloatOps(self):
for dtype in self.float_types:
self._testBinary(
gen_math_ops._real_div,
np.array([3, 3, -1.5, -8, 44], dtype=dtype),
np.array([2, -2, 7, -4, 0], dtype=dtype),
expected=np.array(
[1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype))
self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81))
self._testBinary(
math_ops.pow,
np.array([1, 2], dtype=dtype),
np.zeros(shape=[0, 2], dtype=dtype),
expected=np.zeros(shape=[0, 2], dtype=dtype))
self._testBinary(
math_ops.pow,
np.array([10, 4], dtype=dtype),
np.array([2, 3], dtype=dtype),
expected=np.array([100, 64], dtype=dtype))
self._testBinary(
math_ops.pow,
dtype(2),
np.array([3, 4], dtype=dtype),
expected=np.array([8, 16], dtype=dtype))
self._testBinary(
math_ops.pow,
np.array([[2], [3]], dtype=dtype),
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
self._testBinary(
gen_math_ops._sigmoid_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([-60, -36, -14, 0], dtype=dtype))
self._testBinary(
gen_math_ops._rsqrt_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([-160, -81, -28, -4], dtype=dtype))
self._testBinary(
gen_nn_ops._softplus_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array(
[3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
self._testBinary(
gen_math_ops._tanh_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([-75, -48, -21, 0], dtype=dtype))
self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype))
self._testBinary(
gen_nn_ops._relu6_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype),
np.array(
[0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
self._testBinary(
gen_nn_ops._softmax_cross_entropy_with_logits,
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),
np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype),
expected=[
np.array([1.44019, 2.44019], dtype=dtype),
np.array([[-0.067941, -0.112856, -0.063117, 0.243914],
[-0.367941, -0.212856, 0.036883, 0.543914]],
dtype=dtype),
],
equality_test=self.ListsAreClose)
def testIntOps(self):
for dtype in self.int_types:
self._testBinary(
gen_math_ops._truncate_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -1, 0, -4, 2], dtype=dtype))
def testNumericOps(self):
for dtype in self.numeric_types:
self._testBinary(
math_ops.add,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([11, 22], dtype=dtype))
self._testBinary(
math_ops.add,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([6, 7], dtype=dtype))
self._testBinary(
math_ops.add,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[8], [9]], dtype=dtype))
self._testBinary(
math_ops.sub,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([-9, -18], dtype=dtype))
self._testBinary(
math_ops.sub,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([4, 3], dtype=dtype))
self._testBinary(
math_ops.sub,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([10, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([5, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[10], [7]], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([1, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
self._testBinary(
math_ops.minimum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([1, 5], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[7], [2]], dtype=dtype))
self._testBinary(
math_ops.mul,
np.array([1, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([10, 40], dtype=dtype))
self._testBinary(
math_ops.mul,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([5, 100], dtype=dtype))
self._testBinary(
math_ops.mul,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[70], [14]], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([81, 324], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([16, 9], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[36], [25]], dtype=dtype))
self._testBinary(
nn_ops.bias_add,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([2, -1], dtype=dtype),
expected=np.array([[3, 1], [5, 3]], dtype=dtype))
self._testBinary(
nn_ops.bias_add,
np.array([[[[1, 2], [3, 4]]]], dtype=dtype),
np.array([2, -1], dtype=dtype),
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
def _testDivision(self, dtype):
"""Test cases for division operators."""
self._testBinary(
math_ops.div,
np.array([10, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([1, 10], dtype=dtype))
self._testBinary(
math_ops.div,
dtype(40),
np.array([2, 20], dtype=dtype),
expected=np.array([20, 2], dtype=dtype))
self._testBinary(
math_ops.div,
np.array([[10], [4]], dtype=dtype),
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
self._testBinary(
gen_math_ops._floor_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
for dtype in self.int_types:
self._testDivision(dtype)
def testFloatDivision(self):
for dtype in self.float_types:
self._testDivision(dtype)
def _testRemainder(self, dtype):
"""Test cases for remainder operators."""
self._testBinary(
gen_math_ops._floor_mod,
np.array([3, 3, -1, -8], dtype=dtype),
np.array([2, -2, 7, -4], dtype=dtype),
expected=np.array([1, -1, 6, 0], dtype=dtype))
self._testBinary(
gen_math_ops._truncate_mod,
np.array([3, 3, -1, -8], dtype=dtype),
np.array([2, -2, 7, -4], dtype=dtype),
expected=np.array([1, 1, -1, 0], dtype=dtype))
def testIntRemainder(self):
for dtype in self.int_types:
self._testRemainder(dtype)
def testFloatRemainder(self):
for dtype in self.float_types:
self._testRemainder(dtype)
def testLogicalOps(self):
self._testBinary(
math_ops.logical_and,
np.array([[True, False], [False, True]], dtype=np.bool),
np.array([[False, True], [False, True]], dtype=np.bool),
expected=np.array([[False, False], [False, True]], dtype=np.bool))
self._testBinary(
math_ops.logical_or,
np.array([[True, False], [False, True]], dtype=np.bool),
np.array([[False, True], [False, True]], dtype=np.bool),
expected=np.array([[True, True], [False, True]], dtype=np.bool))
def testComparisons(self):
self._testBinary(
math_ops.equal,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([False, True, False], dtype=np.bool))
self._testBinary(
math_ops.equal,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([False, True, False], dtype=np.bool))
self._testBinary(
math_ops.equal,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[False], [True], [False]], dtype=np.bool))
self._testBinary(
math_ops.not_equal,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([True, False, True], dtype=np.bool))
self._testBinary(
math_ops.not_equal,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([True, False, True], dtype=np.bool))
self._testBinary(
math_ops.not_equal,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[True], [False], [True]], dtype=np.bool))
for greater_op in [math_ops.greater, (lambda x, y: x > y)]:
self._testBinary(
greater_op,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([False, False, True], dtype=np.bool))
self._testBinary(
greater_op,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([True, False, False], dtype=np.bool))
self._testBinary(
greater_op,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[True], [False], [False]], dtype=np.bool))
for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]:
self._testBinary(
greater_equal_op,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([False, True, True], dtype=np.bool))
self._testBinary(
greater_equal_op,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([True, True, False], dtype=np.bool))
self._testBinary(
greater_equal_op,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[True], [True], [False]], dtype=np.bool))
for less_op in [math_ops.less, (lambda x, y: x < y)]:
self._testBinary(
less_op,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([True, False, False], dtype=np.bool))
self._testBinary(
less_op,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([False, False, True], dtype=np.bool))
self._testBinary(
less_op,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[False], [False], [True]], dtype=np.bool))
for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
self._testBinary(
less_equal_op,
np.array([1, 5, 20], dtype=np.float32),
np.array([10, 5, 2], dtype=np.float32),
expected=np.array([True, True, False], dtype=np.bool))
self._testBinary(
less_equal_op,
np.float32(5),
np.array([1, 5, 20], dtype=np.float32),
expected=np.array([False, True, True], dtype=np.bool))
self._testBinary(
less_equal_op,
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[False], [True], [True]], dtype=np.bool))
def testBroadcasting(self):
"""Tests broadcasting behavior of an operator."""
for dtype in self.numeric_types:
self._testBinary(
math_ops.add,
np.array(3, dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([13, 23], dtype=dtype))
self._testBinary(
math_ops.add,
np.array([10, 20], dtype=dtype),
np.array(4, dtype=dtype),
expected=np.array([14, 24], dtype=dtype))
# [1,3] x [4,1] => [4,3]
self._testBinary(
math_ops.add,
np.array([[10, 20, 30]], dtype=dtype),
np.array([[1], [2], [3], [4]], dtype=dtype),
expected=np.array(
[[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
dtype=dtype))
# [3] * [4,1] => [4,3]
self._testBinary(
math_ops.add,
np.array([10, 20, 30], dtype=dtype),
np.array([[1], [2], [3], [4]], dtype=dtype),
expected=np.array(
[[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
dtype=dtype))
def testFill(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.fill,
np.array([], dtype=np.int32),
dtype(-42),
expected=dtype(-42))
self._testBinary(
array_ops.fill,
np.array([1, 2], dtype=np.int32),
dtype(7),
expected=np.array([[7, 7]], dtype=dtype))
self._testBinary(
array_ops.fill,
np.array([3, 2], dtype=np.int32),
dtype(50),
expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype))
# Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below.
def _testMatMul(self, op):
for dtype in self.float_types:
self._testBinary(
op,
np.array([[-0.25]], dtype=dtype),
np.array([[8]], dtype=dtype),
expected=np.array([[-2]], dtype=dtype))
self._testBinary(
op,
np.array([[100, 10, 0.5]], dtype=dtype),
np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
expected=np.array([[123, 354]], dtype=dtype))
self._testBinary(
op,
np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
np.array([[100], [10]], dtype=dtype),
expected=np.array([[130], [250], [680]], dtype=dtype))
self._testBinary(
op,
np.array([[1000, 100], [10, 1]], dtype=dtype),
np.array([[1, 2], [3, 4]], dtype=dtype),
expected=np.array([[1300, 2400], [13, 24]], dtype=dtype))
self._testBinary(
op,
np.array([], dtype=dtype).reshape((2, 0)),
np.array([], dtype=dtype).reshape((0, 3)),
expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
def testMatMul(self):
self._testMatMul(math_ops.matmul)
# TODO(phawkins): failing on GPU, no registered kernel.
def DISABLED_testSparseMatMul(self):
# Binary wrappers for sparse_matmul with different hints
    def SparseMatmulWrapperTF(a, b):
      return math_ops.sparse_matmul(a, b, a_is_sparse=True)
    def SparseMatmulWrapperFT(a, b):
      return math_ops.sparse_matmul(a, b, b_is_sparse=True)
    def SparseMatmulWrapperTT(a, b):
      return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
    self._testMatMul(math_ops.sparse_matmul)
self._testMatMul(SparseMatmulWrapperTF)
self._testMatMul(SparseMatmulWrapperFT)
self._testMatMul(SparseMatmulWrapperTT)
def testBatchMatMul(self):
# Same tests as for tf.matmul above.
self._testMatMul(math_ops.matmul)
# Tests with batches of matrices.
self._testBinary(
math_ops.matmul,
np.array([[[-0.25]]], dtype=np.float32),
np.array([[[8]]], dtype=np.float32),
expected=np.array([[[-2]]], dtype=np.float32))
self._testBinary(
math_ops.matmul,
np.array([[[-0.25]], [[4]]], dtype=np.float32),
np.array([[[8]], [[2]]], dtype=np.float32),
expected=np.array([[[-2]], [[8]]], dtype=np.float32))
self._testBinary(
math_ops.matmul,
np.array(
[[[[1000, 100], [10, 1]], [[2000, 200], [20, 2]]],
[[[3000, 300], [30, 3]], [[4000, 400], [40, 4]]]],
dtype=np.float32),
np.array(
[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]],
[[55, 66], [77, 88]]]],
dtype=np.float32),
expected=np.array(
[[[[1300, 2400], [13, 24]], [[11400, 13600], [114, 136]]],
[[[42900, 79200], [429, 792]], [[250800, 299200], [2508, 2992]]]],
dtype=np.float32))
self._testBinary(
math_ops.matmul,
np.array([], dtype=np.float32).reshape((2, 2, 0)),
np.array([], dtype=np.float32).reshape((2, 0, 3)),
expected=np.array(
[[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]],
dtype=np.float32))
self._testBinary(
math_ops.matmul,
np.array([], dtype=np.float32).reshape((0, 2, 4)),
np.array([], dtype=np.float32).reshape((0, 4, 3)),
expected=np.array([], dtype=np.float32).reshape(0, 2, 3))
# Regression test for b/31472796.
if hasattr(np, "matmul"):
x = np.arange(0, 3 * 5 * 16 * 7, dtype=np.float32).reshape((3, 5, 16, 7))
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
x, x,
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
def testExpandDims(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.expand_dims,
dtype(7),
np.int32(0),
expected=np.array([7], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([42], dtype=dtype),
np.int32(0),
expected=np.array([[42]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([], dtype=dtype),
np.int32(0),
expected=np.array([[]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.int32(0),
expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.int32(1),
expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.int32(2),
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.int32(3),
expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
def testPad(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.pad,
np.array(
[[1, 2, 3], [4, 5, 6]], dtype=dtype),
np.array(
[[1, 2], [2, 1]], dtype=np.int32),
expected=np.array(
[[0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0],
[0, 0, 4, 5, 6, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
dtype=dtype))
def testReshape(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.reshape,
np.array([], dtype=dtype),
np.array([0, 4], dtype=np.int32),
expected=np.zeros(shape=[0, 4], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([2, 3], dtype=np.int32),
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([3, 2], dtype=np.int32),
expected=np.array([[0, 1], [2, 3], [4, 5]], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([-1, 6], dtype=np.int32),
expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([6, -1], dtype=np.int32),
expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([2, -1], dtype=np.int32),
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
self._testBinary(
array_ops.reshape,
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([-1, 3], dtype=np.int32),
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
def testSplit(self):
for dtype in self.numeric_types:
self._testBinary(
lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
np.int32(0),
np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
dtype=dtype),
expected=[
np.array([[[1], [2]]], dtype=dtype),
np.array([[[3], [4]]], dtype=dtype),
np.array([[[5], [6]]], dtype=dtype),
],
equality_test=self.ListsAreClose)
self._testBinary(
lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
np.int32(1),
np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
dtype=dtype),
expected=[
np.array([[[1]], [[3]], [[5]]], dtype=dtype),
np.array([[[2]], [[4]], [[6]]], dtype=dtype),
],
equality_test=self.ListsAreClose)
def testTile(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.tile,
np.array([[6]], dtype=dtype),
np.array([1, 2], dtype=np.int32),
expected=np.array([[6, 6]], dtype=dtype))
self._testBinary(
array_ops.tile,
np.array([[1], [2]], dtype=dtype),
np.array([1, 2], dtype=np.int32),
expected=np.array([[1, 1], [2, 2]], dtype=dtype))
self._testBinary(
array_ops.tile,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([3, 2], dtype=np.int32),
expected=np.array(
[[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]],
dtype=dtype))
self._testBinary(
array_ops.tile,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([1, 1], dtype=np.int32),
expected=np.array(
[[1, 2],
[3, 4]],
dtype=dtype))
self._testBinary(
array_ops.tile,
np.array([[1, 2]], dtype=dtype),
np.array([3, 1], dtype=np.int32),
expected=np.array(
[[1, 2],
[1, 2],
[1, 2]],
dtype=dtype))
def testTranspose(self):
for dtype in self.numeric_types:
self._testBinary(
array_ops.transpose,
np.zeros(shape=[1, 0, 4], dtype=dtype),
np.array([1, 2, 0], dtype=np.int32),
expected=np.zeros(shape=[0, 4, 1], dtype=dtype))
self._testBinary(
array_ops.transpose,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([0, 1], dtype=np.int32),
expected=np.array([[1, 2], [3, 4]], dtype=dtype))
self._testBinary(
array_ops.transpose,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
if __name__ == "__main__":
googletest.main()


@ -0,0 +1,78 @@
"""Build rules for Tensorflow/XLA testing."""
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
def all_backends():
if cuda_is_configured():
return ["cpu", "gpu"]
else:
return ["cpu"]
def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
backends=None, **kwargs):
"""Generates py_test targets, one per XLA backend.
  This rule generates py_test() targets named name_backend, one for each
  backend in all_backends(). The rule also generates a test suite named
  `name` that runs the tests for all backends.
  For example, the following rule generates test cases foo_test_cpu and
  foo_test_gpu, and a test suite named foo_test that runs both.
tf_xla_py_test(
name="foo_test",
srcs="foo_test.py",
deps=[...],
)
Args:
name: Name of the target.
srcs: Sources for the target.
deps: Dependencies of the target.
tags: Tags to apply to the generated targets.
data: Data dependencies of the target.
main: Same as py_test's main attribute.
backends: A list of backends to test. Supported values include "cpu" and
"gpu". If not specified, defaults to all backends.
**kwargs: keyword arguments passed onto the generated py_test() rules.
"""
if backends == None:
backends = all_backends()
test_names = []
for backend in backends:
test_name = "{}_{}".format(name, backend)
backend_tags = ["tf_xla_{}".format(backend)]
backend_args = []
backend_deps = []
backend_data = []
if backend == "cpu":
backend_args += ["--test_device=XLA_CPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
elif backend == "gpu":
backend_args += ["--test_device=XLA_GPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
backend_tags += ["requires-gpu-sm35"]
else:
fail("Unknown backend {}".format(backend))
native.py_test(
name=test_name,
srcs=srcs,
srcs_version="PY2AND3",
args=backend_args,
main="{}.py".format(name) if main == None else main,
data=data + backend_data,
deps=deps + backend_deps,
tags=tags + backend_tags,
**kwargs
)
test_names.append(test_name)
native.test_suite(name=name, tests=test_names)
def generate_backend_suites(backends=[]):
"""Generates per-backend test_suites that run all tests for a backend."""
if not backends:
backends = all_backends()
for backend in backends:
native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
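# Usage sketch (hedged): after a package's BUILD file calls
# generate_backend_suites(), all of its tests tagged for a backend can be run
# together, e.g.
#   bazel test //tensorflow/compiler/tests:cpu_tests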


@ -0,0 +1,102 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the behavior of the auto-compilation pass."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
class ClusteringTest(XLATestCase):
def testAdd(self):
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
val2 = np.array([5, 6, 7, 8], dtype=np.float32)
expected = val1 + val2
with self.test_session():
with self.test_scope():
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
output = math_ops.add(input1, input2)
result = output.eval()
self.assertAllClose(result, expected, rtol=1e-3)
def testAddFromCpuMultiple(self):
val1 = np.array([4, 3, 2, 1]).astype(np.float32)
val2 = np.array([5, 6, 7, 8]).astype(np.float32)
expected = val1 + val2
with self.test_session():
with ops.device(CPU_DEVICE):
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
with self.test_scope():
output = math_ops.add(input1, input2)
      for _ in range(10):
result = output.eval()
self.assertAllClose(result, expected, rtol=1e-3)
def testDeadlock(self):
# Builds a graph of the form:
    #  x -> y
    #       | \
    #  z -> w
# where x and z are placed on the CPU and y and w are placed on the XLA
# device. If y and w are clustered for compilation, then the graph will
# deadlock since the clustered graph will contain a self-loop.
with self.test_session() as sess:
with ops.device(CPU_DEVICE):
x = array_ops.placeholder(dtypes.float32, [2])
with self.test_scope():
y = x * 2
with ops.device(CPU_DEVICE):
z = y * y
with self.test_scope():
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
def testHostMemory(self):
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.int32)
with self.test_scope():
y = x + 1
with ops.device(CPU_DEVICE):
# Place a computation on the CPU, so y and w cannot be merged into the
# same JIT compilation.
z = y * 2
with self.test_scope():
# Argument 'y' is a non-constant output of a previous cluster. Make sure
# it is properly copied to host memory so it can be used as a
# compile-time constant input for this cluster.
w = array_ops.reshape(z, y)
result = sess.run(w, {x: [1, 0]})
expected = np.array([[4], [2]], dtype=np.int32)
self.assertAllClose(expected, result, rtol=1e-3)
if __name__ == "__main__":
googletest.main()


@ -0,0 +1,374 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for XLA Concat Op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class ConcatTest(XLATestCase):
def testHStack(self):
with self.test_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
c = array_ops.concat_v2([p1, p2], 0)
params = {
p1: np.random.rand(4, 4).astype("f"),
p2: np.random.rand(4, 4).astype("f")
}
result = c.eval(feed_dict=params)
self.assertEqual(result.shape, c.get_shape())
self.assertAllEqual(result[:4, :], params[p1])
self.assertAllEqual(result[4:, :], params[p2])
def testVStack(self):
with self.test_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
c = array_ops.concat_v2([p1, p2], 1)
params = {
p1: np.random.rand(4, 4).astype("f"),
p2: np.random.rand(4, 4).astype("f")
}
result = c.eval(feed_dict=params)
self.assertEqual(result.shape, c.get_shape())
self.assertAllEqual(result[:, :4], params[p1])
self.assertAllEqual(result[:, 4:], params[p2])
def testInt32(self):
with self.test_session():
p1 = np.random.rand(2, 3).astype("i")
p2 = np.random.rand(2, 3).astype("i")
x1 = constant_op.constant(p1)
x2 = constant_op.constant(p2)
with self.test_scope():
c = array_ops.concat_v2([x1, x2], 0)
result = c.eval()
self.assertAllEqual(result[:2, :], p1)
self.assertAllEqual(result[2:, :], p2)
def _testRandom(self, dtype):
# Random dims of rank 5
shape = np.random.randint(1, 5, size=5)
# Random number of tensors, but always > 1.
num_tensors = np.random.randint(2, 10)
# Random dim to concat on
concat_dim = np.random.randint(5)
params = {}
if dtype == dtypes.bfloat16:
dtype_feed = dtypes.float32
else:
dtype_feed = dtype
with self.test_session():
p = []
for i in np.arange(num_tensors):
input_shape = shape
input_shape[concat_dim] = np.random.randint(1, 5)
placeholder = array_ops.placeholder(dtype_feed, shape=input_shape)
p.append(placeholder)
t = dtype_feed.as_numpy_dtype
params[placeholder] = np.random.rand(*input_shape).astype(t)
if dtype != dtype_feed:
concat_inputs = [math_ops.cast(p_i, dtype) for p_i in p]
else:
concat_inputs = p
with self.test_scope():
c = array_ops.concat_v2(concat_inputs, concat_dim)
if dtype != dtype_feed:
c = math_ops.cast(c, dtype_feed)
result = c.eval(feed_dict=params)
self.assertEqual(result.shape, c.get_shape())
cur_offset = 0
for i in np.arange(num_tensors):
# The index into the result is the ':' along all dimensions
# except the concat_dim. slice(0, size) is used for ':', and
# a list of slices is used to index into result.
ind = [slice(0, params[p[i]].shape[j]) for j in np.arange(5)]
ind[concat_dim] = slice(cur_offset,
cur_offset + params[p[i]].shape[concat_dim])
cur_offset += params[p[i]].shape[concat_dim]
if dtype == dtype_feed:
self.assertAllEqual(result[ind], params[p[i]])
else:
self.assertAllClose(result[ind], params[p[i]], 0.01)
def testRandom(self):
self._testRandom(dtypes.float32)
self._testRandom(dtypes.int32)
def _testGradientsSimple(self):
with self.test_session():
inp = []
inp_tensors = []
with self.test_scope():
for x in [1, 2, 6]:
shape = [10, x, 2]
t = np.random.rand(*shape).astype("f")
inp.append(t)
inp_tensors.append(
constant_op.constant(
[float(y) for y in t.flatten()],
shape=shape,
dtype=dtypes.float32))
c = array_ops.concat_v2(inp_tensors, 1)
output_shape = [10, 9, 2]
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat_v2(grad, 1)
result = concated_grad.eval()
self.assertAllEqual(result, grad_inp)
def testGradientsSimpleAll(self):
self._testGradientsSimple()
def _testGradientsFirstDim(self):
with self.test_session():
inp = []
inp_tensors = []
with self.test_scope():
for x in [1, 2, 6]:
shape = [x, 10, 2]
t = np.random.rand(*shape).astype("f")
inp.append(t)
inp_tensors.append(
constant_op.constant(
[float(y) for y in t.flatten()],
shape=shape,
dtype=dtypes.float32))
c = array_ops.concat_v2(inp_tensors, 0)
output_shape = [9, 10, 2]
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat_v2(grad, 0)
result = concated_grad.eval()
self.assertAllEqual(result, grad_inp)
def testGradientsFirstDimAll(self):
self._testGradientsFirstDim()
def _testGradientsLastDim(self):
with self.test_session():
inp = []
inp_tensors = []
with self.test_scope():
for x in [1, 2, 6]:
shape = [10, 2, x]
t = np.random.rand(*shape).astype("f")
inp.append(t)
inp_tensors.append(
constant_op.constant(
[float(y) for y in t.flatten()],
shape=shape,
dtype=dtypes.float32))
c = array_ops.concat_v2(inp_tensors, 2)
output_shape = [10, 2, 9]
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat_v2(grad, 2)
result = concated_grad.eval()
self.assertAllEqual(result, grad_inp)
def testGradientsLastDimAll(self):
self._testGradientsLastDim()
def _RunAndVerifyGradientsRandom(self):
# Random dims of rank 5
input_shape = np.random.randint(1, 5, size=5)
# Random number of tensors
num_tensors = np.random.randint(1, 10)
# Random dim to concat on
concat_dim = np.random.randint(5)
concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
with self.test_session():
inp = []
inp_tensors = []
with self.test_scope():
for x in concat_dim_sizes:
shape = input_shape
shape[concat_dim] = x
t = np.random.rand(*shape).astype("f")
inp.append(t)
inp_tensors.append(
constant_op.constant(
[float(y) for y in t.flatten()],
shape=shape,
dtype=dtypes.float32))
c = array_ops.concat_v2(inp_tensors, concat_dim)
output_shape = input_shape
output_shape[concat_dim] = concat_dim_sizes.sum()
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat_v2(grad, concat_dim)
result = concated_grad.eval()
self.assertAllEqual(result, grad_inp)
def testGradientsRandom(self):
for _ in range(5):
self._RunAndVerifyGradientsRandom()
# Re-enable once zero-element Retvals are handled correctly.
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
with self.test_session() as sess:
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
for shape1 in (), (3,):
for n0 in 0, 1, 2:
for n1 in 0, 1, 2:
x0 = np.random.randn(*(shape0 + (n0,) + shape1))
x1 = np.random.randn(*(shape0 + (n1,) + shape1))
correct = np.concatenate([x0, x1], axis=axis)
# TODO(irving): Make tf.concat handle map, then drop list().
xs = list(map(constant_op.constant, [x0, x1]))
c = array_ops.concat_v2(xs, axis)
self.assertAllEqual(c.eval(), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
dxs = sess.run(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testTensorConcatDim0Grad(self):
x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
output_shape = [44, 7, 3]
x_vals = [
np.random.random_sample(x_shape).astype(np.float32)
for x_shape in x_shapes
]
with self.test_session():
with self.test_scope():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat_v2(xs, 0)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
output_shape)
self.assertLess(err, 1e-4)
def testTensorConcatDim1Grad(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [20, 11, 3]
x_vals = [
np.random.random_sample(x_shape).astype(np.float32)
for x_shape in x_shapes
]
with self.test_session():
with self.test_scope():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat_v2(xs, 1)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
output_shape)
self.assertLess(err, 1e-4)
def testConcatTuple(self):
c1 = np.random.rand(4, 4).astype(np.float32)
c2 = np.random.rand(4, 4).astype(np.float32)
with self.test_session():
with self.test_scope():
concat_list_t = array_ops.concat_v2([c1, c2], 0)
concat_tuple_t = array_ops.concat_v2((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
with self.test_session():
with self.test_scope():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
with self.assertRaisesRegexp(
ValueError, r"Can't concatenate scalars \(use tf\.pack instead\)"):
array_ops.concat_v2([scalar, scalar, scalar], dim)
class ConcatOffsetTest(XLATestCase):
def testBasic(self):
with self.test_session() as sess:
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
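# _concat_offset returns each input's starting offset along cdim; the sizes
# along dimension 1 are 3, 7 and 20, so the offsets are 0, 3 and 3 + 7 = 10,
# as asserted below.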
off = gen_array_ops._concat_offset(cdim, [s0, s1, s2])
ans = sess.run(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
class PackTest(XLATestCase):
def testBasic(self):
with self.test_session() as sess:
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
with self.test_session() as sess:
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
with self.test_session() as sess:
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]])
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,526 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Conv2D via the XLA JIT.
The canned results in these tests are created by running each test using the
TensorFlow CPU device and saving the output.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
class Conv2DTest(XLATestCase):
def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected):
"""Tests that tf.nn.conv2d produces the expected value.
Args:
input_sizes: Input tensor dimensions in
[batch, input_rows, input_cols, input_depth].
filter_sizes: Filter tensor dimensions in
[kernel_rows, kernel_cols, input_depth, output_depth].
stride: Stride.
padding: Padding type.
expected: Expected output.
"""
total_size_1 = np.prod(input_sizes)
total_size_2 = np.prod(filter_sizes)
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes)
strides = [1, stride, stride, 1]
with self.test_session() as sess:
with self.test_scope():
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
out = nn_ops.conv2d(
t1, t2, strides=strides, padding=padding, data_format="NHWC")
value = sess.run(out, {t1: x1, t2: x2})
self.assertArrayNear(expected, np.ravel(value), 1e-3)
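# A 1x1 filter makes conv2d a per-pixel matmul: each 3-channel input pixel is
# multiplied by the 3x3 filter matrix to produce a 3-channel output pixel.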
def testConv2D1x1Filter(self):
expected_output = [
30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[1, 1, 3, 3],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D2x2Filter(self):
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x2Filter(self):
expected_output = [
231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0,
936.0, 1029.0
]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[1, 2, 3, 3],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterStride2(self):
expected_output = [2271.0, 2367.0, 2463.0]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
stride=2,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterStride2Same(self):
expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
stride=2,
padding="SAME",
expected=expected_output)
class Conv2DBackpropInputTest(XLATestCase):
def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
padding, expected):
"""Tests that gen_nn_ops.conv2d_backprop_input produces the expected output.
Args:
input_sizes: Input tensor dimensions in
[batch, input_rows, input_cols, input_depth].
filter_sizes: Filter tensor dimensions in
[kernel_rows, kernel_cols, input_depth, output_depth].
out_backprop_sizes: Output gradients tensor dimensions.
stride: Stride.
padding: Padding type.
expected: Expected output.
"""
total_size_1 = np.prod(filter_sizes)
total_size_2 = np.prod(out_backprop_sizes)
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes)
x2 = np.arange(
1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
strides = [1, stride, stride, 1]
with self.test_session() as sess:
with self.test_scope():
t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
out = gen_nn_ops.conv2d_backprop_input(
input_sizes=input_sizes,
filter=t1,
out_backprop=t2,
strides=strides,
padding=padding,
data_format="NHWC")
value = sess.run(out, {t1: x1, t2: x2})
self.assertArrayNear(expected, np.ravel(value), 1e-3)
def testConv2D1x1Filter(self):
expected_output = [
5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127,
41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71,
165, 259, 77, 179, 281, 83, 193, 303, 89, 207, 325, 95, 221, 347.
]
self._VerifyValues(
input_sizes=[1, 4, 4, 3],
filter_sizes=[1, 1, 3, 2],
out_backprop_sizes=[1, 4, 4, 2],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width5(self):
expected_output = [1, 2, 0, 2, 4]
self._VerifyValues(
input_sizes=[1, 1, 5, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width6(self):
expected_output = [1, 2, 0, 2, 4, 0]
self._VerifyValues(
input_sizes=[1, 1, 6, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width7(self):
expected_output = [1, 2, 0, 2, 4, 0, 0]
self._VerifyValues(
input_sizes=[1, 1, 7, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterC1Same(self):
expected_output = [1, 4, 7, 7, 23, 33]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
out_backprop_sizes=[1, 2, 3, 1],
stride=1,
padding="SAME",
expected=expected_output)
def testConv2D2x2Filter(self):
expected_output = [
14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604,
437, 482, 527
]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
out_backprop_sizes=[1, 1, 2, 3],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterSame(self):
expected_output = [
14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217,
1505, 1487, 1883, 2279
]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
out_backprop_sizes=[1, 2, 3, 3],
stride=1,
padding="SAME",
expected=expected_output)
def testConv2D1x2Filter(self):
expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 3, 2, 1],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterSame(self):
expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 3, 3, 1],
stride=1,
padding="SAME",
expected=expected_output)
def testConv2D2x2FilterStride2(self):
expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12]
self._VerifyValues(
input_sizes=[1, 3, 5, 1],
filter_sizes=[1, 3, 1, 1],
out_backprop_sizes=[1, 2, 2, 1],
stride=2,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterStride2Same(self):
expected_output = [1, 2, 2, 3, 4, 6]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=2,
padding="SAME",
expected=expected_output)
class Conv2DBackpropFilterTest(XLATestCase):
def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
padding, expected):
"""Tests that gen_nn_ops.conv2d_backprop_filter produces the right output.
Args:
input_sizes: Input tensor dimensions in
[batch, input_rows, input_cols, input_depth].
filter_sizes: Filter tensor dimensions in
[kernel_rows, kernel_cols, input_depth, output_depth].
out_backprop_sizes: Output gradients tensor dimensions.
stride: Stride.
padding: Padding type.
expected: Expected output.
"""
total_size_1 = np.prod(input_sizes)
total_size_2 = np.prod(out_backprop_sizes)
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
x2 = np.arange(
1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
strides = [1, stride, stride, 1]
with self.test_session() as sess:
with self.test_scope():
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
tensor = gen_nn_ops.conv2d_backprop_filter(
input=t1,
filter_sizes=filter_sizes,
out_backprop=t2,
strides=strides,
padding=padding,
data_format="NHWC")
value = sess.run(tensor, {t1: x1, t2: x2})
self.assertArrayNear(expected, np.ravel(value), 1e-5)
def testConv2D1x1Filter(self):
expected_output = [8056, 8432, 8312, 8704, 8568, 8976]
self._VerifyValues(
input_sizes=[1, 4, 4, 3],
filter_sizes=[1, 1, 3, 2],
out_backprop_sizes=[1, 4, 4, 2],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x2Filter(self):
expected_output = [120, 141]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 3, 2, 1],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterDepth1(self):
expected_output = [5, 8, 14, 17]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D2x2Filter(self):
expected_output = [
17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72,
62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87,
120, 153
]
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
out_backprop_sizes=[1, 1, 2, 3],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width5(self):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 5, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width6(self):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 6, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D1x2FilterStride3Width7(self):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 7, 1],
filter_sizes=[1, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=3,
padding="VALID",
expected=expected_output)
def testConv2D1x3Filter(self):
expected_output = [5, 8, 11]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
filter_sizes=[1, 3, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=1,
padding="VALID",
expected=expected_output)
def testConv2D1x3FilterSame(self):
expected_output = [20, 30, 20]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
filter_sizes=[1, 3, 1, 1],
out_backprop_sizes=[1, 1, 4, 1],
stride=1,
padding="SAME",
expected=expected_output)
def testConv2D1x3FilterSameOutbackprop2(self):
expected_output = [7, 10, 3]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
filter_sizes=[1, 3, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=2,
padding="SAME",
expected=expected_output)
def testConv2D2x2FilterC1Same(self):
expected_output = [91, 58, 32, 17]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
out_backprop_sizes=[1, 2, 3, 1],
stride=1,
padding="SAME",
expected=expected_output)
def testConv2D2x2FilterStride2(self):
expected_output = [92, 102, 112]
self._VerifyValues(
input_sizes=[1, 3, 5, 1],
filter_sizes=[1, 3, 1, 1],
out_backprop_sizes=[1, 2, 2, 1],
stride=2,
padding="VALID",
expected=expected_output)
def testConv2D2x2FilterStride2Same(self):
expected_output = [7, 2, 16, 5]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
out_backprop_sizes=[1, 1, 2, 1],
stride=2,
padding="SAME",
expected=expected_output)
class DepthwiseConv2DTest(XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
def ConfigsToTest(self):
input_sizes = [[4, 35, 35, 2], [4, 147, 147, 2], [3, 299, 299, 3],
[5, 183, 183, 1]]
filter_sizes = [[5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
strides = [1, 3, 2, 2]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
paddings = [SAME, VALID, SAME, SAME, SAME]
for i, f, s, p in zip(input_sizes, filter_sizes, strides, paddings):
yield i, f, s, p
def _VerifyValues(self, input_size, filter_size, stride, padding):
imag = np.random.rand(*input_size).astype(np.float32)
filt = np.random.rand(*filter_size).astype(np.float32)
strides = [1, stride, stride, 1]
with self.test_session():
with self.test_scope():
imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
feed_dict = {imag_ph: imag, filt_ph: filt}
xla_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
padding).eval(feed_dict=feed_dict)
with self.test_session():
with ops.device(self.CPU_DEVICE):
imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
feed_dict = {imag_ph: imag, filt_ph: filt}
cpu_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
padding).eval(feed_dict=feed_dict)
self.assertAllClose(xla_out, cpu_out)
# This is disabled because we need a mechanism to set command-line flags,
# i.e. an implementation of SetCommandLineOption() below.
#
# def _VerifyDummy(self, input_size, filter_size, stride, padding):
# imag = np.random.rand(*input_size).astype(np.float32)
# filt = np.random.rand(*filter_size).astype(np.float32)
# strides = [1, stride, stride, 1]
#
# with self.test_session():
# with self.test_scope():
# imag_ph = tf.placeholder(tf.float32, shape=input_size)
# filt_ph = tf.placeholder(tf.float32, shape=filter_size)
# feed_dict = {imag_ph: imag, filt_ph: filt}
# SetCommandLineOption(
# "tf_tla_depthwise_conv2d_custom_func",
# "DummyDepthwiseConv2dKernel")
# xla_out = tf.nn.depthwise_conv2d(
# imag_ph, filt_ph, strides, padding).eval(feed_dict=feed_dict)
# SetCommandLineOption(
# "tf_tla_depthwise_conv2d_custom_func", "")
#
# expected = np.array(range(np.ravel(xla_out).shape[0]), dtype=np.float32)
# self.assertAllClose(np.ravel(xla_out), expected)
def testBasic(self):
for i, f, s, p in self.ConfigsToTest():
self._VerifyValues(i, f, s, p)
# Test disabled until _VerifyDummy() above can be implemented.
# def testCustomFunc(self):
# if self.has_custom_call:
# for i, f, s, p in self.ConfigsToTest():
# self._VerifyDummy(i, f, s, p)
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/types.h"
using tensorflow::int64;
// A dummy implementation that fills the output with 0, 1, 2,...
// to test the custom call implementation of DepthwiseConv2dNative op.
// TODO(keveman): Test this after adding a real implementation for the kernel.
extern "C" void DummyDepthwiseConv2dKernel(float* output, void** inputs) {
const int64* output_size = reinterpret_cast<const int64*>(inputs[4]);
const int64 total_size =
output_size[0] * output_size[1] * output_size[2] * output_size[3];
for (int64 i = 0; i < total_size; ++i) {
*(output + i) = i;
}
}

View File

@@ -0,0 +1,86 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.dynamic_stitch."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import googletest
class DynamicStitchTest(XLATestCase):
def _AssertDynamicStitchResultIs(self, indices, data, expected):
with self.test_session() as session:
index_placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
]
data_placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
]
with self.test_scope():
output = data_flow_ops.dynamic_stitch(index_placeholders,
data_placeholders)
feed_dict = {}
for placeholder, value in zip(index_placeholders, indices):
feed_dict[placeholder] = value
for placeholder, value in zip(data_placeholders, data):
feed_dict[placeholder] = value
result = session.run(output, feed_dict=feed_dict)
self.assertAllClose(expected, result, rtol=1e-3)
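# dynamic_stitch merges the data tensors into a single tensor such that
# output[indices[m][i]] = data[m][i]; the tests below compare against
# hand-computed expected values.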
def testSimpleEmpty(self):
idx1 = np.array([0, 2], dtype=np.int32)
idx2 = np.array([[1], [3]], dtype=np.int32)
val1 = np.array([[], []], dtype=np.int32)
val2 = np.array([[[]], [[]]], dtype=np.int32)
self._AssertDynamicStitchResultIs(
[idx1, idx2], [val1, val2],
expected=np.array([[], [], [], []], np.int32))
def testSimple1D(self):
val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
val3 = np.array([0, 40, 70], dtype=np.float32)
val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32)
expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32)
self._AssertDynamicStitchResultIs(
[val1, val2], [val3, val4], expected=expected)
def testSimple2D(self):
val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6], dtype=np.int32)
val3 = np.array([2, 3, 5], dtype=np.int32)
val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32)
val5 = np.array([[10, 11], [60, 61]], dtype=np.float32)
val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32)
expected = np.array(
[[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61],
[70, 71]],
dtype=np.float32)
self._AssertDynamicStitchResultIs(
[val1, val2, val3], [val4, val5, val6], expected=expected)
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for Tensorflow functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
class FunctionTest(XLATestCase):
def testFunction(self):
"""Executes a simple TensorFlow function."""
def APlus2B(a, b):
return a + b * 2
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
with self.test_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
return APlus2B(a, b)
a = constant_op.constant(aval, name="a")
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testNestedFunctions(self):
"""Executes two nested TensorFlow functions."""
def TimesTwo(x):
return x * 2
def APlus2B(a, b):
return a + TimesTwo(b)
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
with self.test_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
return APlus2B(a, b)
a = constant_op.constant(aval, name="a")
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_g = Foo(a, b)
result = sess.run(call_g)
self.assertAllClose(result, expected, rtol=1e-3)
def testFunctionMultipleRetvals(self):
"""Executes a function with multiple return values."""
# This function will run on the XLA device
def Func(a, b):
return a + b, a - b
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
with self.test_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
return Func(a, b)
a = constant_op.constant(aval, name="a")
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testFunctionsNoInline(self):
@function.Defun(dtypes.float32, noinline=True)
def TimesTwo(x):
return x * 2
@function.Defun(dtypes.float32, dtypes.float32)
def APlus2B(a, b):
return a + TimesTwo(b)
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = aval + bval * 2
with self.test_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtypes.float32, name="a")
b = array_ops.placeholder(dtypes.float32, name="b")
call = APlus2B(a, b)
result = sess.run(call, {a: aval, b: bval})
self.assertAllClose(result, expected, rtol=1e-3)
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,459 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.compiler import jit
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
jit_scope = jit.experimental_jit_scope
def CompiledKernel(fn, *inputs, **kwargs):
"""Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
name = kwargs.pop("name", None)
noinline = kwargs.pop("noinline", None)
@function.Defun(func_name=name, noinline=noinline, compiled=True)
def Compiled(*args):
return fn(*args)
return Compiled(*inputs)
def RunMetadataLabels(run_metadata):
"""Returns all labels in run_metadata."""
labels = []
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
labels.append(node_stats.timeline_label)
return labels
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
return any([substr in x for x in labels])
def MetadataHasXlaLaunch(run_metadata):
"""Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
class JitLaunchTest(test.TestCase):
# Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensor arguments as there are entries in 'args', and must
# return a tuple of output tensors.
# If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
# actually ran. However, it is sometimes possible for _XlaLaunch ops to be
# constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session() as sess:
placeholders = []
feeds = {}
for arg in args:
placeholder = array_ops.placeholder(
dtypes.as_dtype(arg.dtype), list(arg.shape))
placeholders.append(placeholder)
feeds[placeholder] = arg
compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
direct_op = fn(*placeholders)
run_metadata = config_pb2.RunMetadata()
compiled = sess.run(compiled_op,
feeds,
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
self.assert_(MetadataHasXlaLaunch(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
if (isinstance(compiled, (tuple, list)) and
(isinstance(direct, (tuple, list)))):
for (x, y) in zip(compiled, direct):
self.assertAllClose(x, y, rtol=1e-1)
else:
self.assertAllClose(compiled, direct)
def testNoOutputs(self):
with session_lib.Session() as sess:
# Build a function with a single Const node, whose output is ignored.
fdef = function_pb2.FunctionDef()
fdef.signature.name = "KernelWithNoOutputs"
node = node_def_pb2.NodeDef()
node.op = "Const"
node.name = "ignored"
node.attr["dtype"].type = dtypes.int32.as_datatype_enum
tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[])
node.attr["value"].tensor.CopyFrom(tensor)
fdef.node_def.extend([node])
# Check that calling the result as a compiled kernel doesn't crash.
@function.Defun(compiled=True)
def KernelWithNoOutputs():
return constant_op.constant(100)
# Hack to override the definition. By accessing .definition, we
# force the _DefinedFunction to be initialized internally. Then we
# replace its internal FunctionDef proto. We do this hack here
# because one typically can't construct a KernelWithNoOutputs
# function via the Defun decorator directly.
_ = KernelWithNoOutputs.definition
foo = KernelWithNoOutputs
foo._definition = fdef
call = KernelWithNoOutputs()
sess.run(call, {})
def testAliasing(self):
"""Regression test for compiled functions that return an aliased buffer.
XLA returns aliased buffers if outputs are identical. Tests that
we handle that case.
"""
def AddOnceReturnTwice(x):
y = math_ops.add(x, x)
return y, y
# Exercises compiling a function (say, Foo) which calls another
# function (say, Bar) which is not inlined. When the compiler compiles
# Foo, it needs to symbolically execute Bar correctly regardless of
# whether Bar is inlined or not.
#
# Tests compiled=True and noinline=True.
self._compare(
AddOnceReturnTwice, [np.array(
[[[0.5, -1.0]]], dtype=np.float32)],
noinline=True)
# Tests compiled=True and noinline=False.
self._compare(
AddOnceReturnTwice, [np.array(
[[[0.5, -1.0]]], dtype=np.float32)],
noinline=False)
def testOneConstOutput(self):
"""Test consisting of a single constant return value."""
def OneConstOutput():
return constant_op.constant([-3, 44, 99])
self._compare(OneConstOutput, [], require_kernel_launch=False)
def testConstZeroElementOutput(self):
"""Test consisting of a constant zero element return value."""
def ConstZeroElementOutput():
return array_ops.fill([7, 0], 3.0)
self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)
def testSomeConstOutputs(self):
"""Test kernels that return a mixture of const and non-const outputs."""
def SomeConstOutputs(x):
return constant_op.constant(
[-2, 7]), array_ops.identity(x), constant_op.constant(3.5)
self._compare(
SomeConstOutputs, [np.array(
[[1, 2, 3], [4, 5, 6]], dtype=np.float32)])
def testInt32Input(self):
"""Test an int32-typed input.
On a GPU, int32 tensors will be placed in host memory.
"""
def AddToSelf(x):
return math_ops.add(x, x)
self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
def testMandatoryConstantInput(self):
"""Tests an operator that has a mandatory-constant shape input."""
def FillWithFloat(x):
return array_ops.fill(x, 9.5)
self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])
def testMnistForwardFunc(self):
"""Compute inference function from MNIST beginners tutorial."""
batch_size = 16
image_size = 28 * 28
num_classes = 10
# Define a TensorFlow function to compute the forward pass.
def MnistForward(w, b, x):
return nn_ops.softmax(math_ops.matmul(x, w) + b)
w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
b = np.random.random_sample((num_classes)).astype(np.float32)
x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
self._compare(MnistForward, [w, b, x])
def testExplicitMarking(self):
"""Test explicit marking of operators to compile."""
batch_size = 16
image_size = 28 * 28
num_classes = 10
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
b = array_ops.placeholder(dtypes.float32)
with jit_scope():
y1 = math_ops.matmul(x, w)
y2 = math_ops.add(y1, b)
with jit_scope():
y = math_ops.square(y2)
dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
db = np.random.random_sample((num_classes)).astype(np.float32)
dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
with session_lib.Session() as sess:
run_metadata = config_pb2.RunMetadata()
output = sess.run(y, {x: dx,
w: dw,
b: db},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
self.assert_(MetadataHasXlaLaunch(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
class XlaCompilationTest(test.TestCase):
"""Tests for auto-compilation on CPU/GPU devices."""
def testReshape(self):
"""Tests an operator with compile-time constant and non-constant inputs."""
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
# Reshape's first argument is non-constant in the JIT, but its second
# (shape) argument will be treated as a compile-time constant for
# each JIT compilation.
# We do not use a tf.constant() argument since we want to ensure the
# shape is still a run-time argument to the JIT, and not
# statically known as part of the JIT compilation's input graph.
z = array_ops.reshape(x, y)
run_metadata = config_pb2.RunMetadata()
out = sess.run(z,
{x: np.array([1, 2, 3, 4, 5, 6], np.float32),
y: [-1, 3]},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assert_(MetadataHasXlaLaunch(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
"""Tests that JIT computations can ignore formal parameters."""
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.int32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
z = math_ops.add(x, x)
w = math_ops.add(y, y)
# Pulls 'w' into the same compilation via control dependencies.
with ops.control_dependencies([w]):
n = control_flow_ops.no_op()
with ops.control_dependencies([n]):
t = math_ops.add(z, z)
run_metadata = config_pb2.RunMetadata()
out = sess.run(t, {x: np.int32(7),
y: np.int32(404)},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assert_(MetadataHasXlaLaunch(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
"""Tests that compilation accepts computations containing loops."""
with self.test_session() as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
c = lambda i, _: math_ops.less(i, 5)
b = lambda i, x: (i + 1, x * 2.0 + 1.0)
_, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
run_metadata = config_pb2.RunMetadata()
result = session.run(y, {x: np.float32(2)},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assert_(MetadataHasXlaLaunch(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
"""Tests that compilation handles switch operators."""
with self.test_session() as session:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.placeholder(dtypes.bool)
with jit_scope():
z = x + 1.0
w = control_flow_ops.cond(c, lambda: z, lambda: y)
t = math_ops.add(z, w)
# If JIT compilation chooses to cluster z and t, then execution will
# deadlock.
run_metadata = config_pb2.RunMetadata()
result = session.run(t, {x: np.float32(2),
y: np.float32(4),
c: True},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assert_(MetadataHasXlaLaunch(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
g = ops.Graph()
with g.as_default():
@function.Defun(compiled=True)
def Bar(x, y):
return x + 2 * y
@function.Defun(compiled=True)
def Foo(x):
return Bar(x * x, x * x * x)
@function.Defun()
def Entry(x):
return Foo(x)
inp = array_ops.placeholder(dtypes.float32)
out = Entry(inp)
with self.test_session(graph=g, use_gpu=True) as sess:
run_metadata = config_pb2.RunMetadata()
val = sess.run(out,
feed_dict={inp: [2., 10.]},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertAllClose(val, [20., 2100.])
def testLoopDeadlock(self):
"""Regression test for bug that caused deadlocks in graphs with loops."""
with self.test_session() as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
y = x + 1.0
c = lambda i, _x, _y: math_ops.less(i, 5)
b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
_, _, w = control_flow_ops.while_loop(c, b,
(constant_op.constant(0), y, x))
u = w + y
result = session.run(u, {x: np.float32(2)})
self.assertAllClose(result, np.float32(63), rtol=1e-1)
def testGradient(self):
"""Tests that the backprop function is properly compiled."""
def _Run(compiled):
@function.Defun(compiled=compiled)
def Forward(x):
return math_ops.log(x)
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32)
y = Forward(x)
dx, = gradients_impl.gradients(y, [x], 1.0)
cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L1,
do_function_inlining=True)))
with session_lib.Session(graph=g, config=cfg) as sess:
run_metadata = config_pb2.RunMetadata()
dx_val = sess.run(dx,
feed_dict={x: 100.},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertAllClose(dx_val, 0.01)
return RunMetadataLabels(run_metadata)
# SymGrad[f=log(x)](x, dy) = 1/x * dy
#
# Note: we don't need to compute log(x) for dx due to graph pruning.
# Do not compile the backprop. We should see one Reciprocal and one Mul.
labels = _Run(compiled=False)
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
self.assertFalse(InLabels(labels, "_XlaLaunch"))
# Compile the backprop. One _XlaLaunch.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
self.assertTrue(InLabels(labels, "_XlaLaunch"))
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,129 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Local Response Normalization ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import googletest
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
# Local response normalization tests. The forward tests are copied from
# tensorflow/python/kernel_tests/lrn_op_test.py
class LRNTest(XLATestCase):
def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
beta=0.5):
"""Compute expected result."""
output = copy.deepcopy(input_image)
batch_size = input_image.shape[0]
rows = input_image.shape[1]
cols = input_image.shape[2]
depth = input_image.shape[3]
for b in range(batch_size):
for r in range(rows):
for c in range(cols):
for d in range(depth):
begin = max(0, d - lrn_depth_radius)
end = min(depth, d + lrn_depth_radius + 1)
patch = input_image[b, r, c, begin:end]
output[b, r, c, d] /= (
np.power(bias + alpha * np.sum(patch * patch), beta))
return output
def _RunAndVerify(self, dtype):
with self.test_session():
# random shape
shape = np.random.randint(1, 16, size=4)
# Make depth at least 2 to make it meaningful
shape[3] += 1
p = array_ops.placeholder(dtype, shape=shape)
# random depth_radius, bias, alpha, beta
lrn_depth_radius = np.random.randint(1, shape[3])
bias = 1.0 + np.random.rand()
alpha = 2.0 * np.random.rand()
beta = 2.0 * np.random.rand()
with self.test_scope():
lrn_t = nn.local_response_normalization(
p,
name="lrn",
depth_radius=lrn_depth_radius,
bias=bias,
alpha=alpha,
beta=beta)
params = {p: np.random.rand(*shape).astype("f")}
result = lrn_t.eval(feed_dict=params)
expected = self._LRN(
params[p],
lrn_depth_radius=lrn_depth_radius,
bias=bias,
alpha=alpha,
beta=beta)
err = np.amax(np.abs(result - expected))
print("LRN error for bias ", bias, "alpha ", alpha, " beta ", beta, " is ",
err)
if dtype == dtypes.float32:
self.assertTrue(err < 1e-4)
else:
self.assertTrue(err < 1e-2)
self.assertShapeEqual(expected, lrn_t)
def testCompute(self):
for _ in range(2):
self._RunAndVerify(dtypes.float32)
def testLrnGrad(self):
# Test for LRNGrad that compares against the CPU implementation.
shape = [1, 2, 3, 4]
total_size = np.prod(shape)
in_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
out_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
out_grads_vals = np.arange(1, total_size + 1, dtype=np.float32)
depth_radius = np.random.randint(1, shape[3])
bias = 1.0 + np.random.rand()
alpha = 1.0 * np.random.rand()
beta = 1.0 * np.random.rand()
with self.test_session():
in_image = constant_op.constant(in_image_vals, shape=shape)
out_image = constant_op.constant(out_image_vals, shape=shape)
out_grads = constant_op.constant(out_grads_vals, shape=shape)
with ops.device(CPU_DEVICE):
expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
depth_radius, bias, alpha, beta)
with self.test_scope():
actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
depth_radius, bias, alpha, beta)
expected_val = expected.eval()
actual_val = actual.eval()
self.assertAllClose(actual_val, expected_val, rtol=1e-3)
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,158 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple LSTM layer with benchmarks.
This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed
length sequence. The only deviation from standard LSTM cells is that
activations are clipped, inspired by the GNMT machine translation model.
The GNMT paper has more details: https://arxiv.org/abs/1609.08144
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
def Clip(x):
"""Clips x to the range [-1., 1.]."""
return math_ops.maximum(math_ops.minimum(x, 1.), -1.)
def LSTMCellWeightsShape(num_inputs, num_nodes):
"""Returns the shape of the weights for a single LSTM cell."""
# Dimension 0 accounts for combining x with the previous m state.
# Dimension 1 accounts for the in value and the (in, forget, out) gates.
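# For example, LSTMCellWeightsShape(1024, 1024) == [2048, 4096].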
return [num_inputs + num_nodes, 4 * num_nodes]
def LSTMCell(weights, m_prev, c_prev, x, pad):
"""Unrolls a single LSTM cell with clipped activations forward by one step.
Args:
weights: Weight matrix with shape LSTMCellWeightsShape.
m_prev: Previous m states with shape [batch_size, num_nodes].
c_prev: Previous c states with shape [batch_size, num_nodes].
x: Input with shape [batch_size, num_inputs].
pad: Padding with shape [batch_size, 1]. Each padding value is either
0 or 1, where 1 indicates padding; i.e. the input is shorter than the
sequence length, and the (m, c) states should simply be passed through
from the previous states.
Returns:
The next (m, c) states, each with shape [batch_size, num_nodes].
"""
# Apply weights to the input and previous hidden state.
# The matmul here is the "big" operation.
xm = array_ops.concat_v2([x, m_prev], 1)
xmw = math_ops.matmul(xm, weights)
# Element-wise ops for the standard LSTM cell, with clipped activations.
# XLA can fuse these operations into a single loop.
in_value, in_gate, forget_gate, out_gate = array_ops.split(
value=xmw, num_or_size_splits=4, axis=1)
in_value = math_ops.tanh(in_value)
in_gate = math_ops.sigmoid(in_gate)
forget_gate = math_ops.sigmoid(forget_gate)
out_gate = math_ops.sigmoid(out_gate)
c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value))
m_next = Clip(out_gate * c_next)
# Account for padding.
c_next = c_prev * pad + c_next * (1.0 - pad)
m_next = m_prev * pad + m_next * (1.0 - pad)
return m_next, c_next
def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
"""Unrolls a layer of LSTM cells forward by the sequence length.
The sequence length is determined by the length of x_seq and pad_seq, which
must be the same.
Args:
cell_name: Base name of each cell.
weights: Weight matrix with shape LSTMCellWeightsShape.
m: Initial m states with shape [batch_size, num_nodes].
c: Initial c states with shape [batch_size, num_nodes].
x_seq: List of inputs, each with shape [batch_size, num_inputs].
The length of the list is the sequence length.
pad_seq: List of paddings, each with shape [batch_size, 1].
The length of the list is the sequence length.
Each padding value is either 0 or 1, where 1 indicates padding;
i.e. the input is shorter than the sequence length.
Returns:
List of per-sequence-step outputs, each with shape [batch_size, num_nodes].
Raises:
ValueError: If len(x_seq) != len(pad_seq).
"""
if len(x_seq) != len(pad_seq):
raise ValueError('length of x_seq(%d) != pad_seq(%d)' %
(len(x_seq), len(pad_seq)))
out_seq = []
for seq in range(len(x_seq)):
with ops.name_scope('%s_%d' % (cell_name, seq)):
m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq])
out_seq.append(array_ops.identity(m, name='out'))
return out_seq
def RandomVar(shape, name=None):
"""Returns a variable of the given shape initialized to random values."""
return variables.Variable(
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
def RandomInputs(batch_size, seq_length, num_inputs):
"""Returns randomly initialized (x_seq, pad_seq) sequences."""
x_seq = []
pad_seq = []
with ops.name_scope('inputs'):
for seq in range(seq_length):
x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq))
# Real padding values are always a sequence of 0 followed by a
# sequence of 1, but random values are fine for benchmarking.
pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq))
return x_seq, pad_seq
def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes):
"""Builds a single LSTM layer with random weights and inputs.
Args:
batch_size: Inputs are fed in batches of this size.
seq_length: The sequence length to unroll the LSTM layer.
num_inputs: Dimension of inputs that are fed into each LSTM cell.
num_nodes: The number of nodes in each LSTM cell.
Returns:
(out_seq, weights) pair. The out_seq is a list of per-sequence-step
outputs, each with shape [batch_size, num_nodes]. The weights are a list of
weight variables that may be trained.
"""
weights = RandomVar(
LSTMCellWeightsShape(num_inputs, num_nodes), name='weights')
m = array_ops.zeros([batch_size, num_nodes], name='init_m')
c = array_ops.zeros([batch_size, num_nodes], name='init_c')
x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs)
out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq)
return out_seq, [weights]

View File

@@ -0,0 +1,20 @@
# Text form of tensorflow.tfcompile.Config proto.
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_3/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_4/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/pad_seq_0/read"} shape{dim{size:128}dim{size:1}} }
feed{ id{node_name:"inputs/pad_seq_1/read"} shape{dim{size:128}dim{size:1}} }
feed{ id{node_name:"inputs/pad_seq_2/read"} shape{dim{size:128}dim{size:1}} }
feed{ id{node_name:"inputs/pad_seq_3/read"} shape{dim{size:128}dim{size:1}} }
feed{ id{node_name:"inputs/pad_seq_4/read"} shape{dim{size:128}dim{size:1}} }
feed{ id{node_name:"weights/read"} shape{dim{size:2048}dim{size:4096}} }
feed{ id{node_name:"init_c"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"init_m"} shape{dim{size:128}dim{size:1024}} }
fetch{ id{node_name:"lstm_0/out"} }
fetch{ id{node_name:"lstm_1/out"} }
fetch{ id{node_name:"lstm_2/out"} }
fetch{ id{node_name:"lstm_3/out"} }
fetch{ id{node_name:"lstm_4/out"} }

File diff suppressed because it is too large

View File

@@ -0,0 +1,293 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the LSTM cell and layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.compiler.tests import lstm
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import test
flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 128,
'Inputs are fed in batches of this size, for both '
'inference and training. Larger values cause the matmul '
'in each LSTM cell to have higher dimensionality.')
flags.DEFINE_integer('seq_length', 60,
                     'Length of the unrolled sequence of LSTM cells in a layer. '
'Larger values cause more LSTM matmuls to be run.')
flags.DEFINE_integer('num_inputs', 1024,
'Dimension of inputs that are fed into each LSTM cell.')
flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
flags.DEFINE_string('device', 'gpu',
'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
'For details see documentation for tf.Graph.device.')
flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
'*.pbtxt format to this directory.')
def _DumpGraph(graph, basename):
if FLAGS.dump_graph_dir:
name = os.path.join(FLAGS.dump_graph_dir, basename + '.pbtxt')
with open(name, 'w') as f:
f.write(str(graph.as_graph_def()))
def _Sigmoid(x):
return 1. / (1. + np.exp(-x))
def _Clip(x):
return np.maximum(np.minimum(x, 1.), -1.)
class LSTMTest(test.TestCase):
def setUp(self):
# The tests for a single LSTM cell and LSTM layer use these values as
# inputs. We always set the dimensionality of num_inputs=1; thus batch_size
# actually represents the different input cases.
self._inputs = np.array([[-1.], [-.5], [0.], [.5], [1.]], np.float32)
self._batch_size = len(self._inputs)
def _NextC(self, inputs, weight, m_prev, c_prev):
"""Returns the next c states of an LSTM cell."""
x = (inputs + m_prev) * weight
return _Clip(_Clip(_Sigmoid(x) * c_prev) + _Clip(_Sigmoid(x) * np.tanh(x)))
def _NextM(self, inputs, weight, m_prev, c_prev):
"""Returns the next m states of an LSTM cell."""
x = (inputs + m_prev) * weight
return _Clip(_Sigmoid(x) * self._NextC(inputs, weight, m_prev, c_prev))
def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
pad_scalar):
with self.test_session() as sess:
num_inputs = 1
num_nodes = 1
weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
m_prev = constant_op.constant([[m_prev_scalar]] * self._batch_size)
c_prev = constant_op.constant([[c_prev_scalar]] * self._batch_size)
x = constant_op.constant(self._inputs)
pad = constant_op.constant([[pad_scalar]] * self._batch_size)
m, c = lstm.LSTMCell(weights, m_prev, c_prev, x, pad)
_DumpGraph(sess.graph, 'lstm_cell_%s_%d_%d_%d' %
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM step.
sess.run(variables.global_variables_initializer())
return sess.run([m, c])
def testLSTMCell(self):
# Run with all-0 weights, no padding.
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
self.assertAllClose(m, [[0.]] * self._batch_size)
self.assertAllClose(c, [[0.]] * self._batch_size)
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
self.assertAllClose(m, [[.25]] * self._batch_size)
self.assertAllClose(c, [[.5]] * self._batch_size)
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
self.assertAllClose(m, [[.0]] * self._batch_size)
self.assertAllClose(c, [[.0]] * self._batch_size)
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
self.assertAllClose(m, [[.25]] * self._batch_size)
self.assertAllClose(c, [[.5]] * self._batch_size)
# Run with all-1 weights, no padding.
for m_prev in [0., 1.]:
for c_prev in [0., 1.]:
m, c = self._RunLSTMCell('ones',
init_ops.ones_initializer(), m_prev, c_prev,
0.)
self.assertAllClose(m, self._NextM(self._inputs, 1., m_prev, c_prev))
self.assertAllClose(c, self._NextC(self._inputs, 1., m_prev, c_prev))
# Run with random weights.
for weight in np.random.rand(3):
weight_tf = constant_op.constant(weight, dtypes.float32)
random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
# No padding.
for m_prev in [0., 1.]:
for c_prev in [0., 1.]:
m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 0.)
self.assertAllClose(m,
self._NextM(self._inputs, weight, m_prev, c_prev))
self.assertAllClose(c,
self._NextC(self._inputs, weight, m_prev, c_prev))
# Set padding.
for m_prev in [0., 1.]:
for c_prev in [0., 1.]:
m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 1.)
self.assertAllClose(m, [[m_prev]] * self._batch_size)
self.assertAllClose(c, [[c_prev]] * self._batch_size)
def testLSTMLayerErrors(self):
num_inputs = 1
num_nodes = 1
seq_length = 3
weights = array_ops.zeros(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
m = constant_op.constant([[0.]] * self._batch_size)
c = constant_op.constant([[0.]] * self._batch_size)
x_seq = [constant_op.constant(self._inputs)] * seq_length
pad = constant_op.constant([[0.]] * self._batch_size)
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad])
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 2)
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 4)
def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
pad_scalar):
with self.test_session() as sess:
num_inputs = 1
num_nodes = 1
seq_length = 3
weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
m_init = constant_op.constant([[m_init_scalar]] * self._batch_size)
c_init = constant_op.constant([[c_init_scalar]] * self._batch_size)
x_seq = [constant_op.constant(self._inputs)] * seq_length
pad_seq = [constant_op.constant([[pad_scalar]] * self._batch_size)
] * seq_length
out_seq = lstm.LSTMLayer('lstm', weights, m_init, c_init, x_seq, pad_seq)
_DumpGraph(sess.graph, 'lstm_layer_%s_%d_%d_%d' %
(basename, m_init_scalar, c_init_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM layer.
sess.run(variables.global_variables_initializer())
return sess.run(out_seq)
def testLSTMLayer(self):
# Run with all-0 weights, no padding.
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
self.assertAllClose(o, [[[.25]] * self._batch_size,
[[.125]] * self._batch_size,
[[.0625]] * self._batch_size])
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
self.assertAllClose(o, [[[.25]] * self._batch_size,
[[.125]] * self._batch_size,
[[.0625]] * self._batch_size])
# Run with all-1 weights, no padding.
weight1 = 1.
for m_init in [0., 1.]:
for c_init in [0., 1.]:
o = self._RunLSTMLayer('ones',
init_ops.ones_initializer(), m_init, c_init, 0.)
m0 = self._NextM(self._inputs, weight1, m_init, c_init)
c0 = self._NextC(self._inputs, weight1, m_init, c_init)
self.assertAllClose(o[0], m0)
m1 = self._NextM(self._inputs, weight1, m0, c0)
c1 = self._NextC(self._inputs, weight1, m0, c0)
self.assertAllClose(o[1], m1)
m2 = self._NextM(self._inputs, weight1, m1, c1)
self.assertAllClose(o[2], m2)
# Run with random weights.
for weight in np.random.rand(3):
weight_tf = constant_op.constant(weight, dtypes.float32)
random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
# No padding.
for m_init in [0., 1.]:
for c_init in [0., 1.]:
o = self._RunLSTMLayer('random', random_weight, m_init, c_init, 0.)
m0 = self._NextM(self._inputs, weight, m_init, c_init)
c0 = self._NextC(self._inputs, weight, m_init, c_init)
self.assertAllClose(o[0], m0)
m1 = self._NextM(self._inputs, weight, m0, c0)
c1 = self._NextC(self._inputs, weight, m0, c0)
self.assertAllClose(o[1], m1)
m2 = self._NextM(self._inputs, weight, m1, c1)
self.assertAllClose(o[2], m2)
# Set padding.
o = self._RunLSTMLayer('random', random_weight, 0., 0., 1.)
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
o = self._RunLSTMLayer('random', random_weight, 0., 1., 1.)
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
o = self._RunLSTMLayer('random', random_weight, 1., 0., 1.)
self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
o = self._RunLSTMLayer('random', random_weight, 1., 1., 1.)
self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
class LSTMBenchmark(test.Benchmark):
"""Mcro-benchmarks for a single layer of LSTM cells."""
def _LayerBuilder(self, do_training):
out_seq, weights = lstm.BuildLSTMLayer(FLAGS.batch_size, FLAGS.seq_length,
FLAGS.num_inputs, FLAGS.num_nodes)
name, fetches = ('lstm_layer_inference', out_seq)
if do_training:
# Not a real loss function, but good enough for benchmarking backprop.
loss = math_ops.reduce_sum(math_ops.add_n(out_seq))
dw = gradients_impl.gradients(loss, weights)
name, fetches = ('lstm_layer_training', dw)
_DumpGraph(ops.get_default_graph(),
'%s_%d_%d_%d_%d' % (name, FLAGS.batch_size, FLAGS.seq_length,
FLAGS.num_inputs, FLAGS.num_nodes))
return name, fetches
def benchmarkLayerInference(self):
xla_test.Benchmark(self, lambda: self._LayerBuilder(False), False,
FLAGS.device)
def benchmarkLayerInferenceXLA(self):
xla_test.Benchmark(self, lambda: self._LayerBuilder(False), True,
FLAGS.device)
def benchmarkLayerTraining(self):
xla_test.Benchmark(self, lambda: self._LayerBuilder(True), False,
FLAGS.device)
def benchmarkLayerTrainingXLA(self):
xla_test.Benchmark(self, lambda: self._LayerBuilder(True), True,
FLAGS.device)
if __name__ == '__main__':
test.main()

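A quick cross-check of the zero-weight expectations in testLSTMCell above: with all-zero weights the gate input x is 0, so sigmoid(x) = 0.5 and tanh(x) = 0, giving c = clip(0.5 * c_prev) and m = clip(0.5 * c), i.e. c = 0.5 and m = 0.25 when c_prev = 1. The same arithmetic in NumPy:

# NumPy re-derivation of the zero-weight case checked by testLSTMCell.
import numpy as np

def sigmoid(x):
  return 1. / (1. + np.exp(-x))

def clip(x):
  return np.clip(x, -1., 1.)

inputs, weight, m_prev, c_prev = 0.5, 0., 0., 1.
x = (inputs + m_prev) * weight                                       # 0
c = clip(clip(sigmoid(x) * c_prev) + clip(sigmoid(x) * np.tanh(x)))  # 0.5
m = clip(sigmoid(x) * c)                                             # 0.25
print(m, c)
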
View File

@@ -0,0 +1,209 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for operators with > 3 or arbitrary numbers of arguments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class NAryOpsTest(XLATestCase):
def _testNAry(self, op, args, expected):
with self.test_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
for arg in args
]
feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
output = op(placeholders)
result = session.run(output, feeds)
self.assertAllClose(result, expected, rtol=1e-3)
def testFloat(self):
self._testNAry(math_ops.add_n,
[np.array([[1, 2, 3]], dtype=np.float32)],
expected=np.array([[1, 2, 3]], dtype=np.float32))
self._testNAry(math_ops.add_n,
[np.array([1, 2], dtype=np.float32),
np.array([10, 20], dtype=np.float32)],
expected=np.array([11, 22], dtype=np.float32))
self._testNAry(math_ops.add_n,
[np.array([-4], dtype=np.float32),
np.array([10], dtype=np.float32),
np.array([42], dtype=np.float32)],
expected=np.array([48], dtype=np.float32))
def testConcat(self):
self._testNAry(
lambda x: array_ops.concat_v2(x, 0), [
np.array(
[[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
[[7, 8, 9], [10, 11, 12]], dtype=np.float32)
],
expected=np.array(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))
self._testNAry(
lambda x: array_ops.concat_v2(x, 1), [
np.array(
[[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
[[7, 8, 9], [10, 11, 12]], dtype=np.float32)
],
expected=np.array(
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
def testSplitV(self):
with self.test_session() as session:
with self.test_scope():
output = session.run(
array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
dtype=np.float32),
[2, 2], 1))
expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
self.assertAllEqual(output, expected)
def testStridedSlice(self):
self._testNAry(lambda x: array_ops.strided_slice(*x),
[np.array([[], [], []], dtype=np.float32),
np.array([1, 0], dtype=np.int32),
np.array([3, 0], dtype=np.int32),
np.array([1, 1], dtype=np.int32)],
expected=np.array([[], []], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice(*x),
[np.array([[], [], []], dtype=np.float32),
np.array([1, 0], dtype=np.int64),
np.array([3, 0], dtype=np.int64),
np.array([1, 1], dtype=np.int64)],
expected=np.array([[], []], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice(*x),
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=np.float32),
np.array([1, 1], dtype=np.int32),
np.array([3, 3], dtype=np.int32),
np.array([1, 1], dtype=np.int32)],
expected=np.array([[5, 6], [8, 9]], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice(*x),
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=np.float32),
np.array([0, 2], dtype=np.int32),
np.array([2, 0], dtype=np.int32),
np.array([1, -1], dtype=np.int32)],
expected=np.array([[3, 2], [6, 5]], dtype=np.float32))
self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=np.float32)],
expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
dtype=np.float32))
self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=np.float32)],
expected=np.array([[4], [5], [6]], dtype=np.float32))
def testStridedSliceGrad(self):
# Tests cases where input shape is empty.
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([], dtype=np.int32),
np.array([], dtype=np.int32),
np.array([], dtype=np.int32),
np.array([], dtype=np.int32),
np.float32(0.5)],
expected=np.array(np.float32(0.5), dtype=np.float32))
# Tests case where input shape is non-empty, but gradients are empty.
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([3], dtype=np.int32),
np.array([0], dtype=np.int32),
np.array([0], dtype=np.int32),
np.array([1], dtype=np.int32),
np.array([], dtype=np.float32)],
expected=np.array([0, 0, 0], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([3, 0], dtype=np.int32),
np.array([1, 0], dtype=np.int32),
np.array([3, 0], dtype=np.int32),
np.array([1, 1], dtype=np.int32),
np.array([[], []], dtype=np.float32)],
expected=np.array([[], [], []], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([3, 3], dtype=np.int32),
np.array([1, 1], dtype=np.int32),
np.array([3, 3], dtype=np.int32),
np.array([1, 1], dtype=np.int32),
np.array([[5, 6], [8, 9]], dtype=np.float32)],
expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
dtype=np.float32))
def ssg_test(x):
return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
new_axis_mask=0x1)
self._testNAry(ssg_test,
[np.array([3, 1, 3], dtype=np.int32),
np.array([0, 0, 0, 2], dtype=np.int32),
np.array([0, 3, 1, -4], dtype=np.int32),
np.array([1, 2, 1, -3], dtype=np.int32),
np.array([[[1], [2]]], dtype=np.float32)],
expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
dtype=np.float32))
ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15)
self._testNAry(ssg_test2,
[np.array([4, 4], dtype=np.int32),
np.array([0, 0, 0, 1, 0], dtype=np.int32),
np.array([0, 3, 0, 4, 0], dtype=np.int32),
np.array([1, 2, 1, 2, 1], dtype=np.int32),
np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
[0, 0, 0, 0]], dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([3, 3], dtype=np.int32),
np.array([0, 2], dtype=np.int32),
np.array([2, 0], dtype=np.int32),
np.array([1, -1], dtype=np.int32),
np.array([[1, 2], [3, 4]], dtype=np.float32)],
expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
dtype=np.float32))
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
[np.array([3, 3], dtype=np.int32),
np.array([2, 2], dtype=np.int32),
np.array([0, 1], dtype=np.int32),
np.array([-1, -2], dtype=np.int32),
np.array([[1], [2]], dtype=np.float32)],
expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
dtype=np.float32))
if __name__ == "__main__":
googletest.main()

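The concat expectations in testConcat above can be reproduced directly with NumPy; a minimal cross-check:

# NumPy equivalents of the two concat cases exercised by testConcat.
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
b = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
print(np.concatenate([a, b], axis=0))  # rows stacked, shape (4, 3)
print(np.concatenate([a, b], axis=1))  # columns joined, shape (2, 6)
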
View File

@@ -0,0 +1,61 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for operators with no arguments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import googletest
class NullaryOpsTest(XLATestCase):
def _testNullary(self, op, expected):
with self.test_session() as session:
with self.test_scope():
output = op()
result = session.run(output)
self.assertAllClose(result, expected, rtol=1e-3)
def testNoOp(self):
with self.test_session():
with self.test_scope():
output = control_flow_ops.no_op()
# This should not crash.
output.run()
def testConstants(self):
constants = [
np.float32(42),
np.array([], dtype=np.float32),
np.array([1, 2], dtype=np.float32),
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
dtype=np.float32),
np.array([[[]], [[]]], dtype=np.float32),
np.array([[[[1]]]], dtype=np.float32),
]
for c in constants:
self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
if __name__ == "__main__":
googletest.main()

View File

@@ -0,0 +1,511 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for pooling operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
def NHWCToNCHW(input_tensor):
"""Convert the input from NHWC format to NCHW.
Args:
input_tensor: a 4-D tensor, or a 4-element array representing the same.
Returns:
the converted tensor or a shape array
"""
if isinstance(input_tensor, ops.Tensor):
return array_ops.transpose(input_tensor, [0, 3, 1, 2])
else:
return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]]
def NCHWToNHWC(input_tensor):
"""Convert the input from NCHW format to NHWC.
Args:
input_tensor: a 4-D tensor, or a 4-element array representing the same.
Returns:
the converted tensor or a shape array
"""
if isinstance(input_tensor, ops.Tensor):
return array_ops.transpose(input_tensor, [0, 2, 3, 1])
else:
return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]]
def GetTestConfigs():
"""Get all the valid tests configs to run.
Returns:
all the valid test configs
"""
test_configs = ["NHWC", "NCHW"]
return test_configs
class PoolingTest(XLATestCase):
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected):
"""Verifies the output values of the pooling function.
Args:
      pool_func: Function to be called, e.g., nn_ops.max_pool or
        nn_ops.avg_pool.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
expected: An array containing the expected operation outputs.
"""
total_size = np.prod(input_sizes)
    # Initializes the input tensor with an array of numbers incrementing
    # from 1.
x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
x = x.reshape(input_sizes)
with self.test_session() as sess:
with self.test_scope():
inputs = array_ops.placeholder(dtypes.float32)
t = inputs
if data_format == "NCHW":
t = NHWCToNCHW(t)
ksize = NHWCToNCHW(ksize)
strides = NHWCToNCHW(strides)
t = pool_func(t,
ksize=ksize,
strides=strides,
padding=padding,
data_format=data_format)
if data_format == "NCHW":
t = NCHWToNHWC(t)
actual = sess.run(t, {inputs: x})
self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6)
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
expected):
"""Verifies the output values of the pooling function.
Args:
      pool_func: Pooling function to be called, e.g., nn_ops.max_pool or
        nn_ops.avg_pool.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
expected: An array containing the expected operation outputs.
"""
for data_format in GetTestConfigs():
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
data_format, expected)
def testMaxPoolValidPadding(self):
expected_output = [13.0, 14.0, 15.0]
self._VerifyValues(nn_ops.max_pool,
input_sizes=[1, 3, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="VALID",
expected=expected_output)
def testMaxPoolSamePadding(self):
expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
self._VerifyValues(nn_ops.max_pool,
input_sizes=[1, 2, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=expected_output)
def testMaxPoolSamePaddingNonSquareWindow(self):
# input is:
# [1.0, 2.0
# 3.0 4.0]
#
# Window of [x, x] should do:
#
# [max(1.0, 2.0), max(2.0, padded0),
# max(3.0, 4.0), max(4.0, padded0)]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 2, 2, 1],
ksize=[1, 1, 2, 1],
strides=[1, 1, 1, 1],
padding="SAME",
expected=[2.0, 2.0, 4.0, 4.0])
def testMaxPoolValidPaddingUnevenStride(self):
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 1, 2, 1],
padding="VALID",
expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0])
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 2, 1, 1],
padding="VALID",
expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0])
def testMaxPoolSamePaddingFilter4(self):
expected_output = [
21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
61.0, 62.0, 63.0, 64.0
]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 4, 4, 4],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=expected_output)
def testMaxPoolSamePaddingFilter8(self):
expected_output = [
145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 8, 8, 8],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=expected_output)
# Tests for DepthwiseMaxPooling on CPU only.
def testDepthwiseMaxPool1x1DepthWindow1(self):
# input is:
# [1.0, ..., 10.0] along depth,
#
# We maxpool by depth in patches of 2.
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 1, 1, 10],
ksize=[1, 1, 1, 2],
strides=[1, 1, 1, 2],
padding="SAME",
expected=[2.0, 4.0, 6.0, 8.0, 10.0])
def testDepthwiseMaxPool2x2DepthWindow3(self):
# input is:
#
# a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
# output. Each node has contiguous values, so the depthwise max
# should be multiples of 3.0.
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 2, 2, 6],
ksize=[1, 1, 1, 3],
strides=[1, 1, 1, 3],
padding="SAME",
expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0])
def testKernelSmallerThanStrideValid(self):
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1],
strides=[1, 3, 3, 1],
padding="VALID",
expected=[9, 12, 30, 33])
def testKernelSmallerThanStrideSame(self):
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 3, 3, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 7, 9])
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 4, 4, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 9, 11])
# Average pooling
def testAvgPoolValidPadding(self):
expected_output = [7, 8, 9]
self._VerifyValues(
nn_ops.avg_pool,
input_sizes=[1, 3, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="VALID",
expected=expected_output)
def testAvgPoolSamePadding(self):
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
self._VerifyValues(
nn_ops.avg_pool,
input_sizes=[1, 2, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=expected_output)
class PoolGradTest(XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize,
strides, padding, data_format):
"""Verifies the output values of the pooling gradient function.
Args:
pool_func: Forward pooling function
      pool_grad_func: Pooling gradient function corresponding to pool_func.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
"""
total_size = np.prod(input_sizes)
x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
with self.test_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
with ops.device(self.CPU_DEVICE):
inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
outputs = pool_func(
inputs,
ksize=ksize,
strides=strides,
padding=padding,
data_format="NHWC")
output_vals = np.array(sess.run(outputs, {inputs: x}))
output_gradient_vals = np.arange(
1, output_vals.size + 1, dtype=np.float32)
output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
# Use the Tensorflow CPU pooling gradient to compute the expected input
# gradients.
with ops.device(self.CPU_DEVICE):
output_gradients = array_ops.placeholder(
dtypes.float32, shape=output_vals.shape)
expected_input_gradients = pool_grad_func(
inputs,
outputs,
output_gradients,
ksize=ksize,
strides=strides,
padding=padding,
data_format="NHWC")
expected_input_gradient_vals = sess.run(
expected_input_gradients,
{inputs: x,
output_gradients: output_gradient_vals})
# Run the gradient op on the XLA device
with self.test_scope():
outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
xla_inputs = inputs
xla_outputs = outputs
xla_output_gradients = output_gradients
xla_ksize = ksize
xla_strides = strides
if data_format == "NCHW":
xla_inputs = NHWCToNCHW(inputs)
xla_outputs = NHWCToNCHW(outputs)
xla_output_gradients = NHWCToNCHW(output_gradients)
xla_ksize = NHWCToNCHW(ksize)
xla_strides = NHWCToNCHW(strides)
actual_input_gradients = pool_grad_func(
xla_inputs,
xla_outputs,
xla_output_gradients,
ksize=xla_ksize,
strides=xla_strides,
padding=padding,
data_format=data_format)
if data_format == "NCHW":
actual_input_gradients = NCHWToNHWC(actual_input_gradients)
actual = sess.run(actual_input_gradients, {
inputs: x,
outputs: output_vals,
output_gradients: output_gradient_vals
})
# Compare the Tensorflow and XLA results.
self.assertAllClose(
expected_input_gradient_vals.flatten(),
actual.flatten(),
rtol=1e-5,
atol=1e-6)
self.assertShapeEqual(actual, inputs)
def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize,
strides, padding):
"""Verifies the output values of the pooling function.
Args:
pool_func: Pooling function to be called, e.g., tf.nn.max_pool
pool_grad_func: Corresponding pooling gradient function.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
"""
for data_format in GetTestConfigs():
self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize,
strides, padding, data_format)
def _TestPooling(self, forward_op, backward_op):
# VALID padding
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 3, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="VALID")
# SAME padding
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 2, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="SAME")
# SAME padding, non square window
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 2, 2, 1],
ksize=[1, 1, 2, 1],
strides=[1, 1, 1, 1],
padding="SAME")
# VALID padding, uneven stride
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 1, 2, 1],
padding="VALID")
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 2, 1, 1],
padding="VALID")
# SAME padding, size 4 input
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 4, 4, 4],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="SAME")
# SAME padding, size 8 input
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 8, 8, 8],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding="SAME")
def testMaxPool(self):
self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad)
def testAvgPool(self):
# Wrapper around AvgPoolGrad that ignores extra arguments needed by
# MaxPoolGrad.
def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
data_format):
del outputs # Unused by average-pooling gradients.
return gen_nn_ops._avg_pool_grad(
inputs.get_shape().as_list(),
output_gradients,
ksize=ksize,
strides=strides,
padding=padding,
data_format=data_format)
self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)
# The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
# the stride size, so we only run the following tests on MaxPoolGrad.
def testMaxPoolKernelSmallerThanStrideValid(self):
self._VerifyValues(
nn_ops.max_pool,
gen_nn_ops._max_pool_grad,
input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1],
strides=[1, 3, 3, 1],
padding="VALID")
def testMaxPoolKernelSmallerThanStrideSame(self):
self._VerifyValues(
nn_ops.max_pool,
gen_nn_ops._max_pool_grad,
input_sizes=[1, 3, 3, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME")
self._VerifyValues(
nn_ops.max_pool,
gen_nn_ops._max_pool_grad,
input_sizes=[1, 4, 4, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME")
if __name__ == "__main__":
googletest.main()

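The NHWCToNCHW and NCHWToNHWC helpers above apply fixed transpose permutations, and also handle plain 4-element size lists such as ksize and strides. A small sketch of both cases:

# Illustration of the layout permutations used by the pooling tests.
import numpy as np

nhwc = np.zeros([1, 7, 7, 3], dtype=np.float32)  # batch, height, width, channels
nchw = np.transpose(nhwc, [0, 3, 1, 2])          # same permutation as NHWCToNCHW
print(nchw.shape)                                # (1, 3, 7, 7)

ksize_nhwc = [1, 2, 2, 1]
ksize_nchw = [ksize_nhwc[0], ksize_nhwc[3], ksize_nhwc[1], ksize_nhwc[2]]
print(ksize_nchw)                                # [1, 1, 2, 2]
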
File diff suppressed because it is too large

View File

@@ -0,0 +1,125 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reduction operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class ReduceOpsTest(XLATestCase):
def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
rtol=1e-4, atol=1e-4):
"""Tests that the output of 'tf_reduce_fn' matches numpy's output."""
for test_input in test_inputs:
with self.test_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32)
out = tf_reduce_fn(a, index)
result = sess.run(out, {a: test_input, index: [0]})
self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
rtol=rtol, atol=atol)
result = sess.run(out, {a: test_input, index: [1]})
self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
rtol=rtol, atol=atol)
result = sess.run(out, {a: test_input, index: [-1]})
self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
rtol=rtol, atol=atol)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [-33]})
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [2]})
FLOAT_DATA = [
np.zeros(shape=(2, 0)),
np.zeros(shape=(0, 30)),
np.arange(1, 7).reshape(2, 3),
np.arange(-10, -4).reshape(2, 3),
np.arange(-4, 2).reshape(2, 3),
]
NONEMPTY_FLOAT_DATA = [
np.arange(1, 7).reshape(2, 3),
np.arange(-10, -4).reshape(2, 3),
np.arange(-4, 2).reshape(2, 3),
]
BOOL_DATA = [
np.array([], dtype=np.bool).reshape(2, 0),
np.array([], dtype=np.bool).reshape(0, 3),
np.array([[False, True, False], [True, True, False]]),
]
def testReduceSum(self):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
self.FLOAT_DATA)
def testReduceProd(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
self.FLOAT_DATA)
def testReduceMin(self):
def reference_min(inp, axis):
"""Wrapper around np.amin that returns +infinity for an empty input."""
if inp.shape[axis] == 0:
return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
return np.amin(inp, axis)
self._testReduction(math_ops.reduce_min, reference_min, np.float32,
self.FLOAT_DATA)
def testReduceMax(self):
def reference_max(inp, axis):
"""Wrapper around np.amax that returns -infinity for an empty input."""
if inp.shape[axis] == 0:
return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
return np.amax(inp, axis)
self._testReduction(math_ops.reduce_max, reference_max, np.float32,
self.FLOAT_DATA)
def testReduceMean(self):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
self.NONEMPTY_FLOAT_DATA)
def testReduceAll(self):
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
def testReduceAny(self):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
if __name__ == '__main__':
googletest.main()

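reference_min and reference_max above encode the reduction identities for empty inputs: +infinity for min and -infinity for max. A short NumPy sketch of the empty-axis case covered by FLOAT_DATA:

# Reduction identities over an empty axis, as in reference_min/reference_max.
import numpy as np

empty = np.zeros(shape=(2, 0), dtype=np.float32)
axis = 1
out_shape = empty.shape[:axis] + empty.shape[axis + 1:]
print(np.full(out_shape, float('inf')))   # min over an empty axis -> +inf
print(np.full(out_shape, float('-inf')))  # max over an empty axis -> -inf
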
View File

@@ -0,0 +1,110 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for ternary operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class TernaryOpsTest(XLATestCase):
def _testTernary(self, op, a, b, c, expected):
with self.test_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
output = op(pa, pb, pc)
result = session.run(output, {pa: a, pb: b, pc: c})
self.assertAllClose(result, expected, rtol=1e-3)
def testLinspace(self):
self._testTernary(
math_ops.linspace,
np.float32(1),
np.float32(2),
np.int32(1),
expected=np.array([1], dtype=np.float32))
self._testTernary(
math_ops.linspace,
np.float32(1),
np.float32(4),
np.int32(3),
expected=np.array([1, 2.5, 4], dtype=np.float32))
def testRange(self):
self._testTernary(
math_ops.range,
np.int32(1),
np.int32(2),
np.int32(1),
expected=np.array([1], dtype=np.int32))
self._testTernary(
math_ops.range,
np.int32(1),
np.int32(7),
np.int32(2),
expected=np.array([1, 3, 5], dtype=np.int32))
def testSelect(self):
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array(2, dtype=np.float32),
np.array(7, dtype=np.float32),
expected=np.array(7, dtype=np.float32))
self._testTernary(
array_ops.where,
np.array([0, 1, 1, 0], dtype=np.bool),
np.array([1, 2, 3, 4], dtype=np.float32),
np.array([5, 6, 7, 8], dtype=np.float32),
expected=np.array([5, 2, 3, 8], dtype=np.float32))
self._testTernary(
array_ops.where,
np.array([0, 1, 0], dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(
array_ops.slice,
np.array([[], [], []], dtype=dtype),
np.array([1, 0], dtype=np.int32),
np.array([2, 0], dtype=np.int32),
expected=np.array([[], []], dtype=dtype))
self._testTernary(
array_ops.slice,
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
np.array([0, 1], dtype=np.int32),
np.array([2, 1], dtype=np.int32),
expected=np.array([[2], [5]], dtype=dtype))
if __name__ == "__main__":
googletest.main()

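In testSelect above, a rank-1 boolean condition selects whole rows of the rank-2 operands. The NumPy equivalent needs the condition broadcast across the row axis; a minimal sketch:

# NumPy equivalent of the row-selecting where case in testSelect.
import numpy as np

cond = np.array([0, 1, 0], dtype=bool)
x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
y = np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32)
print(np.where(cond[:, None], x, y))  # [[7, 8], [3, 4], [11, 12]]
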
View File

@@ -0,0 +1,346 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA JIT compiler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
class UnaryOpsTest(XLATestCase):
"""Test cases for unary operators."""
def _testUnary(self, op, inp, expected, equality_test=None):
with self.test_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name="a")
output = op(pinp)
result = session.run(output, {pinp: inp})
if equality_test is None:
equality_test = self.assertAllClose
equality_test(result, expected, rtol=1e-3)
def ListsAreClose(self, result, expected, rtol):
"""Tests closeness of two lists of floats."""
self.assertEqual(len(result), len(expected))
for i in range(len(result)):
self.assertAllClose(result[i], expected[i], rtol)
def testAllTypeOps(self):
for dtype in self.numeric_types:
self._testUnary(
array_ops.diag,
np.array([1, 2, 3, 4], dtype=dtype),
np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
dtype=dtype))
self._testUnary(
array_ops.diag_part,
np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype))
self._testUnary(
array_ops.identity,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
self._testUnary(
array_ops.matrix_diag,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
self._testUnary(
array_ops.matrix_diag_part,
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
self._testUnary(
array_ops.prevent_gradient,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
self._testUnary(
array_ops.squeeze,
np.array([[[[[]]]]], dtype=dtype),
expected=np.array([], dtype=dtype))
self._testUnary(
array_ops.squeeze,
np.array([[[1], [2]]], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
self._testUnary(
array_ops.squeeze,
np.array([[[1]], [[2]]], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
self._testUnary(
array_ops.squeeze,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
expected=np.array([[1, 2], [3, 4]], dtype=dtype))
self._testUnary(
array_ops.stop_gradient,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
def testFloatOps(self):
for dtype in self.float_types:
self._testUnary(
math_ops.ceil,
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-1, 2]], dtype=dtype))
self._testUnary(
math_ops.exp,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0.36787945, 2.7182817]], dtype=dtype))
self._testUnary(
math_ops.floor,
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-2, 1]], dtype=dtype))
# Tests for tf.nn ops.
self._testUnary(
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
# TODO(b/31644876): enable this test case when fixed.
# self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10))
self._testUnary(
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
self._testUnary(
math_ops.reciprocal,
np.array([[1, 2]], dtype=dtype),
expected=np.array([[1, 0.5]], dtype=dtype))
self._testUnary(
math_ops.log,
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0, 0.69314718]], dtype=dtype))
self._testUnary(
math_ops.rsqrt,
np.array([[4, 16]], dtype=dtype),
expected=np.array([[0.5, 0.25]], dtype=dtype))
self._testUnary(
math_ops.sigmoid,
np.array(
[[1, 1, 1, 1],
[1, 2, 3, 4]],
dtype=dtype),
expected=np.array(
[[0.7310586, 0.7310586, 0.7310586, 0.7310586],
[0.7310586, 0.880797, 0.95257413, 0.98201376]],
dtype=dtype))
self._testUnary(
math_ops.sqrt,
np.array([[4, 9]], dtype=dtype),
expected=np.array([[2, 3]], dtype=dtype))
self._testUnary(
math_ops.tanh,
np.array(
[[1, 1, 1, 1],
[1, 2, 3, 4]],
dtype=dtype),
expected=np.array(
[[0.76159418, 0.76159418, 0.76159418, 0.76159418],
[0.76159418, 0.96402758, 0.99505478, 0.99932933]],
dtype=dtype))
self._testUnary(
nn_ops.log_softmax,
np.array(
[[1, 1, 1, 1],
[1, 2, 3, 4]],
dtype=dtype),
expected=np.array(
[[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
dtype=dtype))
self._testUnary(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0, 1]], dtype=dtype))
self._testUnary(
nn_ops.relu6,
np.array([[-0.05, 6.05, 5]], dtype=dtype),
expected=np.array([[0, 6, 5]], dtype=dtype))
self._testUnary(
nn_ops.softmax,
np.array(
[[1, 1, 1, 1],
[1, 2, 3, 4]],
dtype=dtype),
expected=np.array(
[[0.25, 0.25, 0.25, 0.25],
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
dtype=dtype))
self._testUnary(
nn_ops.softplus,
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
def testNumericOps(self):
for dtype in self.numeric_types:
self._testUnary(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
expected=np.array([[2, 1]], dtype=dtype))
self._testUnary(
math_ops.neg,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[1, -1]], dtype=dtype))
self._testUnary(
math_ops.square,
np.array([[-2, 3]], dtype=dtype),
expected=np.array([[4, 9]], dtype=dtype))
self._testUnary(
array_ops.zeros_like,
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
def testLogicalOps(self):
self._testUnary(
math_ops.logical_not,
np.array([[True, False], [False, True]], dtype=np.bool),
expected=np.array([[False, True], [True, False]], dtype=np.bool))
def testBiasAddGrad(self):
self._testUnary(
gen_nn_ops.bias_add_grad,
np.array([[1., 2.], [3., 4.]], dtype=np.float32),
expected=np.array([4., 6.], dtype=np.float32))
self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
dtype=np.float32),
expected=np.array([10., 26.], dtype=np.float32))
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
types = [dtypes.bool, dtypes.int32, dtypes.float32]
for shape in shapes:
for src_type in types:
for dst_type in types:
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
src = src.reshape(shape)
dst = src.astype(dst_type.as_numpy_dtype)
self._testUnary(
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
src,
expected=dst)
def testInvertPermutation(self):
self._testUnary(
array_ops.invert_permutation,
np.array([1, 2, 0], np.int32),
expected=np.array([2, 0, 1], dtype=np.int32))
def testRank(self):
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
for dtype in self.numeric_types:
self._testUnary(rank_op, dtype(7), expected=np.int32(0))
self._testUnary(
rank_op, np.array(
[[], []], dtype=dtype), expected=np.int32(2))
self._testUnary(
rank_op, np.array(
[-1, 1], dtype=dtype), expected=np.int32(1))
self._testUnary(
rank_op, np.array(
[[-1, 1]], dtype=dtype), expected=np.int32(2))
self._testUnary(
rank_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int32(2))
def testShape(self):
shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
for dtype in self.numeric_types:
self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32))
self._testUnary(
shape_op,
np.array([[], []], dtype=dtype),
expected=np.array([2, 0], dtype=np.int32))
self._testUnary(
shape_op,
np.array([-1, 1], dtype=dtype),
expected=np.array([2], dtype=np.int32))
self._testUnary(
shape_op,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([1, 2], dtype=np.int32))
self._testUnary(
shape_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.array([3, 1], dtype=np.int32))
def testSize(self):
size_op = lambda x: array_ops.size_internal(x, optimize=False)
for dtype in self.numeric_types:
self._testUnary(size_op, dtype(7), expected=np.int32(1))
self._testUnary(
size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
self._testUnary(
size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
self._testUnary(
size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
self._testUnary(
size_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int32(3))
def testUnpack(self):
self._testUnary(
array_ops.unpack,
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
expected=[
np.array([1., 2.], dtype=np.float32),
np.array([3., 4.], dtype=np.float32),
np.array([5., 6.], dtype=np.float32),
],
equality_test=self.ListsAreClose)
self._testUnary(lambda x: array_ops.unstack(x, axis=1),
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
expected=[
np.array([1., 3., 5.], dtype=np.float32),
np.array([2., 4., 6.], dtype=np.float32),
],
equality_test=self.ListsAreClose)
if __name__ == "__main__":
googletest.main()

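testCast above (like the random_weight lambdas in lstm_test.py) binds the loop variable through a default argument, lambda x, dst_type=dst_type: ...; without that, every closure created in the loop would see only the final loop value. A short illustration:

# Why the tests capture loop variables via lambda default arguments.
late = [lambda x: x + i for i in range(3)]
print([f(0) for f in late])    # [2, 2, 2]: all closures see the last i
early = [lambda x, i=i: x + i for i in range(3)]
print([f(0) for f in early])   # [0, 1, 2]: each captures its own i
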
View File

@@ -0,0 +1,81 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for XLA devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class XlaDeviceTest(test.TestCase):
def testCopies(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
def testLoops(self):
"""Tests that loops work on XLA devices."""
with session_lib.Session() as session:
x = array_ops.placeholder(dtypes.float32)
with ops.device("device:XLA_CPU:0"):
c = lambda i, _: math_ops.less(i, 5)
b = lambda i, x: (i + 1, x * 2.0 + 1.0)
_, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
result = session.run(y, {x: np.float32(2)})
self.assertAllClose(result, np.float32(95), rtol=1e-3)
def testCond(self):
"""Tests that tf.cond works on XLA devices."""
with session_lib.Session() as session:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.placeholder(dtypes.bool)
with ops.device("device:XLA_CPU:0"):
z = x + 1.0
w = control_flow_ops.cond(c, lambda: z, lambda: y)
t = math_ops.add(z, w)
result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True})
self.assertAllClose(result, np.float32(6), rtol=1e-3)
if __name__ == "__main__":
test.main()

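The expected value in testLoops above follows from applying the loop body x -> 2x + 1 five times starting from 2: 2 -> 5 -> 11 -> 23 -> 47 -> 95. The same unrolled arithmetic:

# Unrolled arithmetic behind testLoops' expected result.
x = 2.0
for _ in range(5):
  x = x * 2.0 + 1.0
print(x)  # 95.0
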
View File

@@ -0,0 +1,148 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of XLA test case."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import re
from tensorflow.contrib.compiler import jit
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
FLAGS = flags.FLAGS
flags.DEFINE_string('test_device', None,
'Tensorflow device on which to place operators under test')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
'Path to a file with a list of tests that should not run.')
class XLATestCase(test.TestCase):
"""XLA test cases are parameterized test cases."""
def __init__(self, method_name='runTest'):
super(XLATestCase, self).__init__(method_name)
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')
self.all_tf_types = [
dtypes.DType(types_pb2.DataType.Value(name))
for name in FLAGS.types.split(',')
]
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
self.int_types = [
dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
]
self.float_types = [
dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
]
self.numeric_types = self.int_types + self.float_types
# Parse the manifest file, if any, into a regex identifying tests to
# disable
self.disabled_regex = None
if FLAGS.disabled_manifest is not None:
comments_re = re.compile('#.*$')
manifest_file = open(FLAGS.disabled_manifest, 'r')
lines = manifest_file.read().splitlines()
lines = [comments_re.sub('', l).strip() for l in lines]
self.disabled_regex = re.compile('|'.join(lines))
manifest_file.close()
def setUp(self):
name = '{}.{}'.format(type(self).__name__, self._testMethodName)
if self.disabled_regex is not None and self.disabled_regex.match(name):
logging.info('Disabled test case: %s', name)
self.skipTest('{} is disabled by manifest.'.format(name))
return
logging.info('Start test case: %s', name)
def tearDown(self):
logging.info('End test case: %s', self._testMethodName)
@contextlib.contextmanager
def test_session(self):
"""Custom implementation of test_session() for XLA tests.
We override the standard Tensorflow test_session() since it is too
specific to CPU and GPU tests. In particular, we want to disable soft
placement and explicitly assign ops to devices under test.
Yields:
A session to use when running a test case.
"""
graph = ops.Graph()
with session.Session(graph=graph) as sess, graph.as_default():
yield sess
@contextlib.contextmanager
def test_scope(self):
"""Test scope that runs tests on a Tensorflow/XLA device.
Uses a compilation_scope() to mark operators to compile.
Yields:
A scope to apply to the operators under test.
"""
with ops.device('device:{}:0'.format(self.device)):
yield
def Benchmark(tf_bench, builder_fn, use_xla_jit, device):
"""Build a graph and run benchmarks against it, with or without XLA.
Args:
tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
builder_fn: A function that builds a graph when invoked, and returns
(name, fetches), where name is the name of the test, and fetches
is a list of tensors to fetch as output.
use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF.
device: The tensorflow device to run on, e.g. "cpu", "gpu".
"""
with ops.Graph().as_default():
name = None
targets = []
with ops.device(device):
fetches = []
jit_scope = jit.experimental_jit_scope
with jit_scope(compile_ops=use_xla_jit):
name, fetches = builder_fn()
# We only want to benchmark the operations themselves, and not the data
# transfer of the result(s). Non-compiled identity ops ensure XLA
# doesn't know we're dropping the results, otherwise it might compile
# away the entire computation.
for fetch in fetches:
targets.append(array_ops.identity(fetch).op)
config = config_pb2.ConfigProto(allow_soft_placement=True)
with session.Session(config=config) as sess:
sess.run(variables.global_variables_initializer())
xla = 'xla_' if use_xla_jit else ''
tf_bench.run_op_benchmark(
sess, targets, name='%s_%s%s' % (name, xla, device))

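For orientation, a minimal sketch of a test built on XLATestCase using the test_session and test_scope helpers above; the class and test names are illustrative, and the imports mirror those already used by the op tests in this commit.

# Minimal XLATestCase-based test, following the pattern of nullary_ops_test.py.
import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


class AddTest(XLATestCase):  # hypothetical example test

  def testAdd(self):
    with self.test_session() as sess:
      with self.test_scope():
        out = math_ops.add(constant_op.constant(1.0), constant_op.constant(1.0))
      result = sess.run(out)
    self.assertAllClose(result, np.float32(2.0), rtol=1e-3)


if __name__ == '__main__':
  googletest.main()
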
Some files were not shown because too many files have changed in this diff.