Initial open-source release of XLA: Accelerated Linear Algebra.
XLA is a compiler-based linear algebra execution engine that targets CPUs, GPUs, and custom accelerators. XLA is still experimental; we are releasing it early to get the community involved.
Change: 143990941
parent 7ad7e4dfae
commit 1e67c90e2c

20  configure
@@ -112,6 +112,26 @@ else
  sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
fi

## Enable XLA.
while [ "$TF_ENABLE_XLA" == "" ]; do
  read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
  case $INPUT in
    [Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
    [Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
    "" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
    * ) echo "Invalid selection: " $INPUT;;
  esac
done

if [ "$TF_ENABLE_XLA" == "1" ]; then
  # Update Bazel build configuration.
  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
else
  # Update Bazel build configuration.
  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
fi


# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
@@ -95,6 +95,26 @@ filegroup(
        "//tensorflow/c:all_files",
        "//tensorflow/cc:all_files",
        "//tensorflow/cc/saved_model:all_files",
        "//tensorflow/compiler/aot:all_files",
        "//tensorflow/compiler/aot/tests:all_files",
        "//tensorflow/compiler/jit:all_files",
        "//tensorflow/compiler/jit/graphcycles:all_files",
        "//tensorflow/compiler/jit/legacy_flags:all_files",
        "//tensorflow/compiler/tests:all_files",
        "//tensorflow/compiler/tf2xla:all_files",
        "//tensorflow/compiler/tf2xla/kernels:all_files",
        "//tensorflow/compiler/xla:all_files",
        "//tensorflow/compiler/xla/client:all_files",
        "//tensorflow/compiler/xla/client/lib:all_files",
        "//tensorflow/compiler/xla/legacy_flags:all_files",
        "//tensorflow/compiler/xla/port:all_files",
        "//tensorflow/compiler/xla/service:all_files",
        "//tensorflow/compiler/xla/service/cpu:all_files",
        "//tensorflow/compiler/xla/service/gpu:all_files",
        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files",
        "//tensorflow/compiler/xla/service/llvm_ir:all_files",
        "//tensorflow/compiler/xla/tests:all_files",
        "//tensorflow/compiler/xla/tools:all_files",
        "//tensorflow/contrib:all_files",
        "//tensorflow/contrib/android:all_files",
        "//tensorflow/contrib/bayesflow:all_files",
218  tensorflow/compiler/aot/BUILD  Normal file
@@ -0,0 +1,218 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
|
||||
# Optional runtime utilities for use by code generated by tfcompile.
|
||||
cc_library(
|
||||
name = "runtime",
|
||||
srcs = ["runtime.cc"],
|
||||
hdrs = ["runtime.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "runtime_test",
|
||||
srcs = ["runtime_test.cc"],
|
||||
deps = [
|
||||
":runtime",
|
||||
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Don't depend on this directly; this is only used for the benchmark test
|
||||
# generated by tf_library.
|
||||
cc_library(
|
||||
name = "tf_library_test_main",
|
||||
testonly = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//tensorflow/core:test_main"],
|
||||
)
|
||||
|
||||
xla_proto_library(
|
||||
name = "tfcompile_proto",
|
||||
srcs = ["tfcompile.proto"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_lib",
|
||||
srcs = [
|
||||
"codegen.cc",
|
||||
"compile.cc",
|
||||
"flags.cc",
|
||||
"tfcompile_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"codegen.h",
|
||||
"compile.h",
|
||||
"flags.h",
|
||||
"tfcompile_util.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime", # needed by codegen to print aligned_buffer_bytes
|
||||
":tfcompile_proto",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "codegen_test",
|
||||
srcs = ["codegen_test.cc"],
|
||||
data = ["codegen_test_h.golden"],
|
||||
deps = [
|
||||
":tfcompile_lib",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tfcompile_util_test",
|
||||
srcs = ["tfcompile_util_test.cc"],
|
||||
deps = [
|
||||
":tfcompile_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "tfcompile",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":tfcompile_main"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_main",
|
||||
srcs = ["tfcompile_main.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tfcompile_lib",
|
||||
":tfcompile_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
# NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that
|
||||
# tfcompile.bzl correctly handles usage from outside of the package that it is
|
||||
# defined in.
|
||||
|
||||
# A simple test of tf_library from a text protobuf, mostly to enable the
|
||||
# benchmark_test.
|
||||
tf_library(
|
||||
name = "test_graph_tfadd",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd.config.pbtxt",
|
||||
cpp_class = "AddComp",
|
||||
graph = "test_graph_tfadd.pbtxt",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# Utility library for benchmark binaries, used by the *_benchmark rules that are
|
||||
# added by the tfcompile bazel macro.
|
||||
cc_library(
|
||||
name = "benchmark",
|
||||
srcs = ["benchmark.cc"],
|
||||
hdrs = ["benchmark.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
# The purpose of the benchmark library is to support building an aot
|
||||
# binary with minimal dependencies, to demonstrate small binary sizes.
|
||||
#
|
||||
# KEEP THE DEPENDENCIES MINIMAL.
|
||||
"//tensorflow/core:framework_lite",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_extra_android",
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "benchmark_test",
|
||||
srcs = ["benchmark_test.cc"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":benchmark",
|
||||
":test_graph_tfadd",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
test_suite(
|
||||
name = "all_tests",
|
||||
tags = ["manual"],
|
||||
tests = [
|
||||
":benchmark_test",
|
||||
":test_graph_tfadd_test",
|
||||
"//tensorflow/compiler/aot/tests:all_tests",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"benchmark_main.template", # used by tf_library(...,gen_benchmark=True)
|
||||
"test.cc", # used by tf_library(...,gen_test=True)
|
||||
])
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
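As a usage sketch (assumptions are noted in the comments), a small binary that depends on the ":test_graph_tfadd" rule above could drive the generated AddComp class directly, following the same pattern that benchmark_main.template uses further down:

// Sketch only: assumes a cc_binary whose deps include ":test_graph_tfadd", so
// that the tfcompile-generated header and its AddComp class are available.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <cstdio>

#include "tensorflow/compiler/aot/test_graph_tfadd.h"  // generated by tf_library
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

int main() {
  Eigen::ThreadPool pool(1 /* num_threads */);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  AddComp add;                   // constructor allocates arg/result/temp buffers
  add.set_thread_pool(&device);  // Eigen device used by the compiled computation
  if (!add.Run()) {
    std::printf("Run failed: %s\n", add.error_msg().c_str());
    return 1;
  }
  return 0;
}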
138  tensorflow/compiler/aot/benchmark.cc  Normal file
@@ -0,0 +1,138 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// The purpose of the benchmark library is to support building an aot binary
|
||||
// with minimal dependencies, to demonstrate small binary sizes.
|
||||
//
|
||||
// KEEP THE DEPENDENCIES MINIMAL.
|
||||
|
||||
#include "tensorflow/compiler/aot/benchmark.h"
|
||||
|
||||
#include <sys/time.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace benchmark {
|
||||
|
||||
// Returns current wall time in micros.
|
||||
//
|
||||
// TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use
|
||||
// the implementation without pulling in all of the Env dependencies.
|
||||
static double NowMicros() {
|
||||
struct timeval tv;
|
||||
gettimeofday(&tv, NULL);
|
||||
return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec;
|
||||
}
|
||||
|
||||
void DumpStatsToStdout(const Stats& stats) {
|
||||
// Compute stats.
|
||||
std::vector<int64> sorted_us(stats.per_iter_us);
|
||||
std::sort(sorted_us.begin(), sorted_us.end());
|
||||
const size_t count_us = sorted_us.size();
|
||||
double sum_us = 0;
|
||||
size_t count_us_trimmed = 0;
|
||||
double sum_us_trimmed = 0;
|
||||
size_t count_us_best = 0;
|
||||
double sum_us_best = 0;
|
||||
static constexpr float trim_ratio = 0.25;
|
||||
static constexpr float best_ratio = 0.1;
|
||||
const size_t count_trimmed = count_us * trim_ratio;
|
||||
const size_t count_best = count_us * best_ratio;
|
||||
for (size_t i = 0; i < sorted_us.size(); ++i) {
|
||||
const int64 us = sorted_us[i];
|
||||
sum_us += us;
|
||||
if (i >= count_trimmed && i < count_us - count_trimmed) {
|
||||
sum_us_trimmed += us;
|
||||
++count_us_trimmed;
|
||||
}
|
||||
if (i < count_best) {
|
||||
sum_us_best += us;
|
||||
++count_us_best;
|
||||
}
|
||||
}
|
||||
// Prepare nicely-formatted data.
|
||||
const int kBufSize = 1000;
|
||||
char buf[kBufSize];
|
||||
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
|
||||
const string label_trimmed(buf);
|
||||
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
|
||||
const string label_best(buf);
|
||||
std::vector<std::pair<string, double>> groups = {
|
||||
{"Best:", sorted_us.front()},
|
||||
{"Worst:", sorted_us.back()},
|
||||
{"Median:", sorted_us[count_us / 2]},
|
||||
{"Mean:", sum_us / count_us},
|
||||
{label_trimmed, sum_us_trimmed / count_us_trimmed},
|
||||
{label_best, sum_us_best / count_us_best},
|
||||
};
|
||||
int max_label_size = 0;
|
||||
double max_us = 0;
|
||||
for (const auto& g : groups) {
|
||||
if (g.first.size() > max_label_size) {
|
||||
max_label_size = g.first.size();
|
||||
}
|
||||
if (g.second > max_us) {
|
||||
max_us = g.second;
|
||||
}
|
||||
}
|
||||
int max_digits = 1;
|
||||
while (max_us >= 10.0) {
|
||||
max_us /= 10.0;
|
||||
++max_digits;
|
||||
}
|
||||
// Dump stats out.
|
||||
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
|
||||
stats.total_us);
|
||||
for (const auto& g : groups) {
|
||||
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
|
||||
g.second);
|
||||
}
|
||||
}
|
||||
|
||||
void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
|
||||
// If neither max_seconds or max_iters is set, stop at kDefaultMicros.
|
||||
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
|
||||
? Options::kDefaultMicros
|
||||
: options.max_micros;
|
||||
printf("Running benchmark for %lld us\n", max_us);
|
||||
const int64 start_us = NowMicros();
|
||||
int64 iters = 0;
|
||||
while (true) {
|
||||
const int64 iter_start_us = NowMicros();
|
||||
fn();
|
||||
const int64 end_us = NowMicros();
|
||||
// Collect stats and decide whether to stop.
|
||||
stats->per_iter_us.push_back(end_us - iter_start_us);
|
||||
const int64 total_us = end_us - start_us;
|
||||
++iters;
|
||||
if ((max_us > 0 && total_us >= max_us) ||
|
||||
(options.max_iters > 0 && iters >= options.max_iters)) {
|
||||
stats->total_us = total_us;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
70  tensorflow/compiler/aot/benchmark.h  Normal file
@@ -0,0 +1,70 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Contains benchmark functions used with the code-generated benchmarks that can
// be used to test a model on android. See also code generation rules in
// tfcompile.bzl.
//
// This is separate from the built-in micro-benchmarks, because we want to:
// 1. show a binary with minimal dependencies, to show a close-to-lower-bound
//    binary size.
// 2. compile on Android.
#ifndef TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
#define TENSORFLOW_COMPILER_AOT_BENCHMARK_H_

#include <functional>
#include <string>
#include <vector>

#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tfcompile {
namespace benchmark {

// Options specifies options for benchmarks of functions generated by tfcompile.
struct Options {
  // kDefaultMicros specifies the default time to run the benchmark, and is used
  // if neither max_iters nor max_micros is set.
  static const int64 kDefaultMicros = 3000000;

  int64 max_iters = 0;   // Maximum iterations to run, ignored if <= 0.
  int64 max_micros = 0;  // Maximum microseconds to run, ignored if <= 0.
};

// Stats holds statistics collected during benchmarking.
struct Stats {
  std::vector<int64> per_iter_us;  // Per-iteration deltas in us.
  int64 total_us;                  // Total time in us.

  Stats() : total_us(0) { per_iter_us.reserve(5000); }
};

// DumpStatsToStdout printfs to stdout stats in a multi-line human-friendly
// form.
void DumpStatsToStdout(const Stats& stats);

// BenchmarkFn is the signature of the function generated by tfcompile.
typedef std::function<void()> BenchmarkFn;

// Benchmark runs a benchmark of the function `fn`, collecting stats in `stats`.
// Use `options` to configure benchmarking options.
void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats);

}  // namespace benchmark
}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
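A minimal usage sketch of this API (illustrative only, not part of this commit; work() is a hypothetical stand-in for a tfcompile-generated Run() call):

#include "tensorflow/compiler/aot/benchmark.h"

namespace {
// Hypothetical workload; in practice this would invoke a generated class.
void work() { /* ... run the compiled computation ... */ }
}  // namespace

int main() {
  tensorflow::tfcompile::benchmark::Options options;
  options.max_micros = 2000000;  // Stop after roughly two seconds of iterations.
  tensorflow::tfcompile::benchmark::Stats stats;
  tensorflow::tfcompile::benchmark::Benchmark(options, [] { work(); }, &stats);
  tensorflow::tfcompile::benchmark::DumpStatsToStdout(stats);
  return 0;
}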
51  tensorflow/compiler/aot/benchmark_main.template  Normal file
@@ -0,0 +1,51 @@
// Generated by the tf_library build rule. DO NOT EDIT!
//
// This file contains the main function and logic for benchmarking code
// generated by tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be
// rewritten to real values before this file can be compiled.
//
// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
//
// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
// generates a cc_binary rule for you.

// These macros must be defined before eigen files are included.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

// clang-format off
#include "{{TFCOMPILE_HEADER}}"  // NOLINT(whitespace/braces)
// clang-format on

#include "tensorflow/compiler/aot/benchmark.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Macros that expand to tokens based on the entry point name.
// clang-format off
#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}}  // NOLINT(whitespace/braces)
// clang-format on

namespace tensorflow {
namespace tfcompile {

int Main(int argc, char** argv) {
  Eigen::ThreadPool pool(1 /* num_threads */);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  CPP_CLASS computation;
  computation.set_thread_pool(&device);

  benchmark::Options options;
  benchmark::Stats stats;
  benchmark::Benchmark(options, [&] { computation.Run(); }, &stats);
  benchmark::DumpStatsToStdout(stats);
  return 0;
}

}  // namespace tfcompile
}  // namespace tensorflow

int main(int argc, char** argv) {
  return tensorflow::tfcompile::Main(argc, argv);
}
46  tensorflow/compiler/aot/benchmark_test.cc  Normal file
@@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/benchmark.h"

#include "tensorflow/compiler/aot/test_graph_tfadd.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace tfcompile {
namespace benchmark {
namespace {

// There isn't much we can verify in a stable fashion, so we just run the
// benchmark with max_iters, and ensure we end up with that many iter stats.
TEST(Benchmark, Benchmark) {
  AddComp add;

  Options options;
  options.max_iters = 1;
  Stats stats1;
  Benchmark(options, [&] { add.Run(); }, &stats1);
  EXPECT_EQ(stats1.per_iter_us.size(), 1);

  options.max_iters = 5;
  Stats stats5;
  Benchmark(options, [&] { add.Run(); }, &stats5);
  EXPECT_EQ(stats5.per_iter_us.size(), 5);
}

}  // namespace
}  // namespace benchmark
}  // namespace tfcompile
}  // namespace tensorflow
579  tensorflow/compiler/aot/codegen.cc  Normal file
@@ -0,0 +1,579 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/aot/runtime.h"
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/str_util.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
namespace {
|
||||
|
||||
// Convert an XLA type into a C++ type.
|
||||
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
|
||||
switch (type) {
|
||||
case xla::PRED:
|
||||
*str = "bool";
|
||||
break;
|
||||
case xla::S8:
|
||||
*str = "tensorflow::int8";
|
||||
break;
|
||||
case xla::S16:
|
||||
*str = "tensorflow::int16";
|
||||
break;
|
||||
case xla::S32:
|
||||
*str = "tensorflow::int32";
|
||||
break;
|
||||
case xla::S64:
|
||||
*str = "tensorflow::int64";
|
||||
break;
|
||||
case xla::U8:
|
||||
*str = "tensorflow::uint8";
|
||||
break;
|
||||
case xla::U16:
|
||||
*str = "tensorflow::uint16";
|
||||
break;
|
||||
case xla::U32:
|
||||
*str = "tensorflow::uint32";
|
||||
break;
|
||||
case xla::U64:
|
||||
*str = "tensorflow::uint64";
|
||||
break;
|
||||
case xla::F32:
|
||||
*str = "float";
|
||||
break;
|
||||
case xla::F64:
|
||||
*str = "double";
|
||||
break;
|
||||
default:
|
||||
return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
|
||||
" has no equivalent in C++");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
|
||||
// values. There are `n` entries in `sizes`.
|
||||
size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
|
||||
size_t total = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (sizes[i] != -1) {
|
||||
total += sizes[i];
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
// Fills in arg_sizes with the byte size of each positional arg.
|
||||
Status ComputeArgSizes(const CompileResult& compile_result,
|
||||
std::vector<int64>* arg_sizes) {
|
||||
const xla::ProgramShape& ps = compile_result.program_shape;
|
||||
for (int i = 0; i < ps.parameters_size(); ++i) {
|
||||
if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) {
|
||||
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
|
||||
// always last, and must be represented as an opaque type.
|
||||
const xla::PrimitiveType type = ps.parameters(i).element_type();
|
||||
if (type != xla::OPAQUE) {
|
||||
return errors::InvalidArgument(
|
||||
"expected final context arg to be opaque, but got type: ",
|
||||
xla::PrimitiveType_Name(type), ", from program shape: ",
|
||||
xla::ShapeUtil::HumanString(ps));
|
||||
}
|
||||
arg_sizes->push_back(-1);
|
||||
} else {
|
||||
arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
|
||||
ps.parameters(i), compile_result.pointer_size));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
|
||||
// are used to generate methods for args and results.
|
||||
Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
std::vector<std::pair<string, string>>* rewrites) {
|
||||
string type;
|
||||
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
||||
std::vector<string> dim_vars;
|
||||
string dim_sizes, indices;
|
||||
if (xla::ShapeUtil::Rank(shape) == 0 ||
|
||||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
||||
dim_sizes = "[1]";
|
||||
indices = "[0]";
|
||||
} else {
|
||||
for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
|
||||
dim_vars.push_back(strings::StrCat("size_t dim", dim));
|
||||
dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
|
||||
indices += strings::StrCat("[dim", dim, "]");
|
||||
}
|
||||
}
|
||||
rewrites->push_back({"{{I}}", strings::StrCat(i)});
|
||||
rewrites->push_back({"{{TYPE}}", type});
|
||||
rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||
rewrites->push_back({"{{INDICES}}", indices});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
|
||||
// for the name. Note that the rewriting strategy is roughly O(N*M), where N is
|
||||
// the size of the code and M is the number of rewrites. It's fine for now
|
||||
// since N and M are pretty small.
|
||||
//
|
||||
// TODO(toddw): If this becomes a problem, we should be able to change the
|
||||
// algorithm to O(N) by using a state machine, e.g. regexps or a real
|
||||
// text-templating mechanism.
|
||||
string RewriteWithName(const string& name, string code,
|
||||
const std::vector<std::pair<string, string>>& rewrites) {
|
||||
str_util::ReplaceAllPairs(&code, rewrites);
|
||||
str_util::ReplaceAll(&code, "{{NAME}}", name);
|
||||
return code;
|
||||
}
|
||||
|
||||
// Generate methods for args (inputs).
|
||||
Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
|
||||
const CompileResult& compile_result, string* methods) {
|
||||
*methods += R"(
|
||||
void** args() { return args_; }
|
||||
const void *const *args() const { return args_; }
|
||||
)";
|
||||
size_t num_args = ps.parameters_size();
|
||||
if (compile_result.has_context_arg) {
|
||||
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
|
||||
// always last, and is set in the class constructor.
|
||||
num_args--;
|
||||
}
|
||||
if (config.feed_size() != num_args) {
|
||||
return errors::InvalidArgument("mismatch between feed_size(",
|
||||
config.feed_size(), ") and num_args(",
|
||||
num_args, ")");
|
||||
}
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
|
||||
const string code = R"(
|
||||
void set_arg{{NAME}}_data(void* data) {
|
||||
args_[{{I}}] = data;
|
||||
}
|
||||
{{TYPE}}* arg{{NAME}}_data() {
|
||||
return static_cast<{{TYPE}}*>(args_[{{I}}]);
|
||||
}
|
||||
{{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
|
||||
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
|
||||
args_[{{I}}])){{INDICES}};
|
||||
}
|
||||
const {{TYPE}}* arg{{NAME}}_data() const {
|
||||
return static_cast<const {{TYPE}}*>(args_[{{I}}]);
|
||||
}
|
||||
const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
args_[{{I}}])){{INDICES}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
|
||||
if (!config.feed(i).name().empty()) {
|
||||
*methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Generate methods for results (outputs).
|
||||
Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
|
||||
string* methods) {
|
||||
if (ps.result().element_type() != xla::TUPLE) {
|
||||
// Non-tuple (i.e. single-result) case.
|
||||
if (config.fetch_size() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"non-tuple result implies 1 fetch, but got ", config.fetch_size(),
|
||||
" fetches");
|
||||
}
|
||||
*methods += R"(
|
||||
void** results() { return temps_ + kResultIndex; }
|
||||
const void *const *results() const { return temps_ + kResultIndex; }
|
||||
)";
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites));
|
||||
const string code = R"(
|
||||
{{TYPE}}* result{{NAME}}_data() {
|
||||
return static_cast<{{TYPE}}*>(temps_[kResultIndex]);
|
||||
}
|
||||
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
|
||||
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
|
||||
temps_[kResultIndex])){{INDICES}};
|
||||
}
|
||||
const {{TYPE}}* result{{NAME}}_data() const {
|
||||
return static_cast<const {{TYPE}}*>(temps_[kResultIndex]);
|
||||
}
|
||||
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
temps_[kResultIndex])){{INDICES}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName("0", code, rewrites);
|
||||
if (!config.fetch(0).name().empty()) {
|
||||
*methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Tuple (i.e. multi-result) case.
|
||||
if (config.fetch_size() != ps.result().tuple_shapes_size()) {
|
||||
return errors::InvalidArgument("mismatch between fetch_size(",
|
||||
config.feed_size(), ") and tuple_size(",
|
||||
ps.result().tuple_shapes_size(), ")");
|
||||
}
|
||||
*methods += R"(
|
||||
void** results() {
|
||||
return static_cast<void**>(temps_[kResultIndex]);
|
||||
}
|
||||
const void *const *results() const {
|
||||
return static_cast<const void *const *>(temps_[kResultIndex]);
|
||||
}
|
||||
)";
|
||||
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
|
||||
string code = R"(
|
||||
{{TYPE}}* result{{NAME}}_data() {
|
||||
return static_cast<{{TYPE}}*>(
|
||||
static_cast<void**>(temps_[kResultIndex])[{{I}}]);
|
||||
}
|
||||
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
|
||||
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
|
||||
static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
|
||||
}
|
||||
const {{TYPE}}* result{{NAME}}_data() const {
|
||||
return static_cast<{{TYPE}}*>(
|
||||
static_cast<void**>(temps_[kResultIndex])[{{I}}]);
|
||||
}
|
||||
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
|
||||
if (!config.fetch(i).name().empty()) {
|
||||
*methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
|
||||
const CompileResult& compile_result, string* header) {
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
const int64 result_index = compile_result.aot->result_buffer_index();
|
||||
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
|
||||
if (result_index < 0 || result_index > temp_sizes.size()) {
|
||||
return errors::InvalidArgument("result index: ", result_index,
|
||||
" is outside the range of temp sizes: [0,",
|
||||
temp_sizes.size(), ")");
|
||||
}
|
||||
|
||||
// Compute sizes and generate methods.
|
||||
std::vector<int64> arg_sizes;
|
||||
TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
|
||||
const xla::ProgramShape& ps = compile_result.program_shape;
|
||||
string methods_arg, methods_result;
|
||||
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
|
||||
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
|
||||
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
|
||||
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
|
||||
const size_t arg_bytes_aligned =
|
||||
runtime::aligned_buffer_bytes(iarg.data(), iarg.size());
|
||||
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
|
||||
const size_t temp_bytes_aligned =
|
||||
runtime::aligned_buffer_bytes(itemp.data(), itemp.size());
|
||||
const size_t temp_bytes_total =
|
||||
total_buffer_bytes(itemp.data(), itemp.size());
|
||||
|
||||
// Create rewrite strings for the optional context arg.
|
||||
string context_include;
|
||||
string context_set_arg, context_set_thread_pool, context_member_var;
|
||||
string run_result = "true";
|
||||
string error_msg = "tensorflow::string()";
|
||||
if (compile_result.has_context_arg) {
|
||||
// NOTE: Extra spaces and newlines are used to ensure nice formatting.
|
||||
context_include =
|
||||
"#include "
|
||||
"\"tensorflow/compiler/tf2xla/"
|
||||
"xla_local_runtime_context.h\"\n";
|
||||
context_set_arg = " args_[kNumArgs-1] = &context_;\n";
|
||||
context_set_thread_pool = " context_.thread_pool = pool;\n";
|
||||
context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n";
|
||||
run_result = "!context_.error";
|
||||
error_msg = "context_.error_msg";
|
||||
}
|
||||
|
||||
// Create rewrite strings for namespace start and end.
|
||||
string ns_start;
|
||||
for (const string& n : opts.namespaces) {
|
||||
ns_start += strings::StrCat("namespace ", n, " {\n");
|
||||
}
|
||||
ns_start += "\n";
|
||||
string ns_end("\n");
|
||||
for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
|
||||
const string& n = opts.namespaces[i];
|
||||
ns_end += strings::StrCat("} // end namespace ", n, "\n");
|
||||
}
|
||||
|
||||
// Use a poor-man's text templating mechanism; first populate the full header
|
||||
// with placeholder tokens, and then rewrite the tokens with real values.
|
||||
*header =
|
||||
R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
|
||||
//
|
||||
// This header was generated via ahead-of-time compilation of a TensorFlow
|
||||
// graph. An object file corresponding to this header was also generated.
|
||||
// This header gives access to the functionality in that object file.
|
||||
//
|
||||
// clang-format off
|
||||
|
||||
#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
|
||||
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
|
||||
|
||||
{{CONTEXT_INCLUDE}}
|
||||
#include "tensorflow/compiler/aot/runtime.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace Eigen { class ThreadPoolDevice; }
|
||||
|
||||
// (Implementation detail) Entry point to the function in the object file.
|
||||
extern "C" void {{ENTRY}}(
|
||||
void* result, xla::ExecutableRunOptions* run_options,
|
||||
void** args, void** temps);
|
||||
|
||||
{{NS_START}}
|
||||
// {{CLASS}} represents a computation previously specified in a
|
||||
// TensorFlow graph, now compiled into executable code. Usage example:
|
||||
//
|
||||
// {{CLASS}} computation;
|
||||
// // ...set args using computation.argN methods
|
||||
// CHECK(computation.Run());
|
||||
// // ...inspect results using computation.resultN methods
|
||||
//
|
||||
// The Run method invokes the actual computation, with inputs read from arg
|
||||
// buffers, and outputs written to result buffers. Each Run call may also use
|
||||
// a set of temporary buffers for the computation.
|
||||
//
|
||||
// By default each instance of this class manages its own arg, result and temp
|
||||
// buffers. The AllocMode constructor parameter may be used to modify the
|
||||
// buffer allocation strategy.
|
||||
//
|
||||
// Under the default allocation strategy, this class is thread-compatible:
|
||||
// o Calls to non-const methods require exclusive access to the object.
|
||||
// o Concurrent calls to const methods are OK, if those calls are made while
|
||||
// it is guaranteed that no thread may call a non-const method.
|
||||
//
|
||||
// The logical function signature is:
|
||||
// {{PROGRAM_SHAPE}}
|
||||
//
|
||||
// Memory stats:
|
||||
// arg bytes total: {{ARG_BYTES_TOTAL}}
|
||||
// arg bytes aligned: {{ARG_BYTES_ALIGNED}}
|
||||
// temp bytes total: {{TEMP_BYTES_TOTAL}}
|
||||
// temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
|
||||
class {{CLASS}} {
|
||||
public:
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const intptr_t* ArgSizes() {
|
||||
static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
|
||||
return kArgSizes;
|
||||
}
|
||||
|
||||
// AllocMode controls the buffer allocation mode.
|
||||
enum class AllocMode {
|
||||
// Allocate all buffers - args, results and temps.
|
||||
ARGS_RESULTS_AND_TEMPS,
|
||||
|
||||
// Only allocate result and temp buffers.
|
||||
// Use set_argN_data to set argument buffers before Run is called.
|
||||
RESULTS_AND_TEMPS_ONLY,
|
||||
};
|
||||
|
||||
{{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
|
||||
if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
|
||||
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
|
||||
ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
|
||||
}
|
||||
{{CONTEXT_SET_ARG}}
|
||||
alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
|
||||
TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
|
||||
}
|
||||
|
||||
~{{CLASS}}() {
|
||||
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
|
||||
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
|
||||
}
|
||||
|
||||
// Sets the thread pool to use during the Run call.
|
||||
{{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
|
||||
run_options_.set_intra_op_thread_pool(pool);
|
||||
{{CONTEXT_SET_THREAD_POOL}}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Runs the computation, with inputs read from arg buffers, and outputs
|
||||
// written to result buffers. Returns true on success and false on failure.
|
||||
bool Run() {
|
||||
{{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_);
|
||||
return {{RUN_RESULT}};
|
||||
}
|
||||
|
||||
// Returns the error message from the previous failed Run call.
|
||||
tensorflow::string error_msg() const { return {{ERROR_MSG}}; }
|
||||
|
||||
// Arg methods for managing input buffers. Buffers are in row-major order.
|
||||
// There is a set of methods for each positional argument, with the following
|
||||
// general form:
|
||||
//
|
||||
// void set_argN_data(void* data)
|
||||
// Sets the buffer of type T for positional argument N. May be called in
|
||||
// any AllocMode. Must be called before Run to have an affect. Must be
|
||||
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
|
||||
// to set the argument buffers.
|
||||
//
|
||||
// T* argN_data()
|
||||
// Returns the buffer of type T for positional argument N.
|
||||
//
|
||||
// T& argN(...dim indices...)
|
||||
// Returns a reference to the value of type T for positional argument N,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// void** args()
|
||||
// Returns an array of argument buffers, where args()[N] is the buffer for
|
||||
// positional argument N.
|
||||
{{METHODS_ARG}}
|
||||
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
// Must only be called after a successful Run call. There is a set of methods
|
||||
// for each positional result, with the following general form:
|
||||
//
|
||||
// T* resultN_data()
|
||||
// Returns the buffer of type T for positional result N.
|
||||
//
|
||||
// T& resultN(...dim indices...)
|
||||
// Returns a reference to the value of type T for positional result N,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// void** results()
|
||||
// Returns an array of result buffers, where results()[N] is the buffer for
|
||||
// positional result N.
|
||||
//
|
||||
// Unlike the arg methods, there is no set_resultN_data method. The result
|
||||
// buffers are managed internally, and may change after each call to Run.
|
||||
{{METHODS_RESULT}}
|
||||
|
||||
private:
|
||||
// Number of result and temporary buffers for the compiled computation.
|
||||
static constexpr size_t kNumTemps = {{TEMP_NUM}};
|
||||
// The 0-based index of the result in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
|
||||
|
||||
// Byte size of each result / temporary buffer. There are kNumTemps entries.
|
||||
static const intptr_t* TempSizes() {
|
||||
static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
|
||||
return kTempSizes;
|
||||
}
|
||||
|
||||
void* args_[kNumArgs];
|
||||
void* temps_[kNumTemps];
|
||||
void* alloc_args_ = nullptr;
|
||||
void* alloc_temps_ = nullptr;
|
||||
xla::ExecutableRunOptions run_options_;
|
||||
{{CONTEXT_MEMBER_VAR}}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}});
|
||||
};
|
||||
{{NS_END}}
|
||||
|
||||
#endif // TFCOMPILE_GENERATED_{{ENTRY}}_H_
|
||||
|
||||
// clang-format on
|
||||
)";
|
||||
// The replacement strategy is naive, but good enough for our purposes.
|
||||
const std::vector<std::pair<string, string>> rewrites = {
|
||||
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
|
||||
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
|
||||
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
{"{{CONTEXT_INCLUDE}}\n", context_include},
|
||||
{"{{CONTEXT_MEMBER_VAR}}\n", context_member_var},
|
||||
{"{{CONTEXT_SET_ARG}}\n", context_set_arg},
|
||||
{"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool},
|
||||
{"{{ENTRY}}", compile_result.entry_point},
|
||||
{"{{ERROR_MSG}}", error_msg},
|
||||
{"{{METHODS_ARG}}\n", methods_arg},
|
||||
{"{{METHODS_RESULT}}\n", methods_result},
|
||||
{"{{NS_END}}\n", ns_end},
|
||||
{"{{NS_START}}\n", ns_start},
|
||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
|
||||
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
|
||||
{"{{RUN_RESULT}}", run_result},
|
||||
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
|
||||
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
|
||||
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
|
||||
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
|
||||
};
|
||||
str_util::ReplaceAllPairs(header, rewrites);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ParseCppClass(const string& cpp_class, string* class_name,
|
||||
std::vector<string>* namespaces) {
|
||||
class_name->clear();
|
||||
namespaces->clear();
|
||||
size_t begin = 0;
|
||||
size_t end = 0;
|
||||
while ((end = cpp_class.find("::", begin)) != string::npos) {
|
||||
const string ns = cpp_class.substr(begin, end - begin);
|
||||
TF_RETURN_IF_ERROR(ValidateCppIdent(
|
||||
ns, "in namespace component of cpp_class: " + cpp_class));
|
||||
namespaces->push_back(ns);
|
||||
begin = end + 2; // +2 to skip the two colons
|
||||
}
|
||||
const string name = cpp_class.substr(begin);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateCppIdent(name, "in class name of cpp_class: " + cpp_class));
|
||||
*class_name = name;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
53  tensorflow/compiler/aot/codegen.h  Normal file
@@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_
#define TENSORFLOW_COMPILER_AOT_CODEGEN_H_

#include <string>
#include <vector>

#include "tensorflow/compiler/aot/compile.h"

namespace tensorflow {
namespace tfcompile {

// HeaderOpts specifies options for header-file generation.
struct HeaderOpts {
  // The name of the generated C++ class, wrapping the generated function.
  string class_name;

  // Namespaces specifies a list of C++ namespaces to add to the generated
  // header. If empty, all symbols will be in the global namespace.
  std::vector<string> namespaces;
};

// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
Status GenerateHeader(const HeaderOpts& opts, const Config& config,
                      const CompileResult& compile_result, string* header);

// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
// components. The syntax is [[<optional_namespace>::],...]<class_name>. This
// mirrors the C++ syntax for referring to a class, where multiple namespaces
// may precede the class name, separated by double-colons.
Status ParseCppClass(const string& cpp_class, string* class_name,
                     std::vector<string>* namespaces);

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_CODEGEN_H_
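A small usage sketch for ParseCppClass (illustrative only; the class string is made up, and the expected split mirrors the ParseCppClassTest cases in codegen_test.cc below):

#include <vector>

#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

int main() {
  tensorflow::string class_name;
  std::vector<tensorflow::string> namespaces;
  // "foo::bar::MyClass" should yield class_name == "MyClass" and
  // namespaces == {"foo", "bar"}.
  tensorflow::Status s = tensorflow::tfcompile::ParseCppClass(
      "foo::bar::MyClass", &class_name, &namespaces);
  return s.ok() ? 0 : 1;
}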
137  tensorflow/compiler/aot/codegen_test.cc  Normal file
@@ -0,0 +1,137 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
class ParseCppClassTest : public ::testing::Test {
|
||||
protected:
|
||||
void ExpectOK(const string& cpp_class, const string& want_class_name,
|
||||
const std::vector<string>& want_namespaces) {
|
||||
string class_name;
|
||||
std::vector<string> namespaces;
|
||||
TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
|
||||
EXPECT_EQ(class_name, want_class_name);
|
||||
EXPECT_EQ(namespaces, want_namespaces);
|
||||
}
|
||||
|
||||
void ExpectFail(const string& cpp_class) {
|
||||
string class_name;
|
||||
std::vector<string> namespaces;
|
||||
EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ParseCppClassTest, ParseOK) {
|
||||
ExpectOK("MyClass", "MyClass", {});
|
||||
ExpectOK("_MyClass", "_MyClass", {});
|
||||
ExpectOK("a::MyClass", "MyClass", {"a"});
|
||||
ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
|
||||
ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
|
||||
ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
|
||||
ExpectOK("foo::MyClass", "MyClass", {"foo"});
|
||||
ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
|
||||
ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
|
||||
// Make sure we didn't skip a valid letter or digit
|
||||
string ident;
|
||||
for (char c = 'a'; c <= 'z'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
for (char c = 'A'; c <= 'Z'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
for (char c = '0'; c <= '9'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
ident += "_";
|
||||
ExpectOK(ident, ident, {});
|
||||
ExpectOK(ident + "::" + ident, ident, {ident});
|
||||
ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
|
||||
}
|
||||
|
||||
TEST_F(ParseCppClassTest, ParseFail) {
|
||||
ExpectFail("");
|
||||
ExpectFail("::");
|
||||
ExpectFail("::MyClass"); // valid C++, but disallowed for simpler code.
|
||||
ExpectFail("0");
|
||||
ExpectFail("a.b");
|
||||
ExpectFail("a:b");
|
||||
ExpectFail("good::.bad");
|
||||
ExpectFail("good:::bad");
|
||||
ExpectFail("good:: bad");
|
||||
ExpectFail("good::0bad");
|
||||
}
|
||||
|
||||
TEST(GenerateHeader, Golden) {
|
||||
HeaderOpts opts;
|
||||
opts.class_name = "MyClass";
|
||||
opts.namespaces = {"foo", "bar"};
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("feed0");
|
||||
feed->set_name("myfeed");
|
||||
feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("feed1");
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("fetch0");
|
||||
fetch->set_name("myfetch");
|
||||
CompileResult compile_result;
|
||||
compile_result.aot.reset(
|
||||
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
|
||||
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
|
||||
{
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
||||
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
|
||||
xla::ShapeUtil::MakeOpaqueShape(),
|
||||
},
|
||||
xla::ShapeUtil::MakeShape(xla::U32, {5, 6}));
|
||||
compile_result.has_context_arg = true;
|
||||
compile_result.entry_point = "entry_point";
|
||||
compile_result.pointer_size = 8;
|
||||
string header;
|
||||
TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header));
|
||||
|
||||
// Compare against the golden file.
|
||||
const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
"compiler/aot/codegen_test_h.golden");
|
||||
// To update the golden file, flip update_golden to true and run the
|
||||
// following:
|
||||
// bazel test --test_strategy=local \
|
||||
// third_party/tensorflow/compiler/aot:codegen_test
|
||||
const bool update_golden = false;
|
||||
if (update_golden) {
|
||||
TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header));
|
||||
}
|
||||
string golden_data;
|
||||
TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data));
|
||||
EXPECT_EQ(header, golden_data);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
268  tensorflow/compiler/aot/codegen_test_h.golden  Normal file
@@ -0,0 +1,268 @@
|
||||
// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
|
||||
//
|
||||
// This header was generated via ahead-of-time compilation of a TensorFlow
|
||||
// graph. An object file corresponding to this header was also generated.
|
||||
// This header gives access to the functionality in that object file.
|
||||
//
|
||||
// clang-format off
|
||||
|
||||
#ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
|
||||
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
|
||||
#include "tensorflow/compiler/aot/runtime.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace Eigen { class ThreadPoolDevice; }
|
||||
|
||||
// (Implementation detail) Entry point to the function in the object file.
|
||||
extern "C" void entry_point(
|
||||
void* result, xla::ExecutableRunOptions* run_options,
|
||||
void** args, void** temps);
|
||||
|
||||
namespace foo {
|
||||
namespace bar {
|
||||
|
||||
// MyClass represents a computation previously specified in a
|
||||
// TensorFlow graph, now compiled into executable code. Usage example:
|
||||
//
|
||||
// MyClass computation;
|
||||
// // ...set args using computation.argN methods
|
||||
// CHECK(computation.Run());
|
||||
// // ...inspect results using computation.resultN methods
|
||||
//
|
||||
// The Run method invokes the actual computation, with inputs read from arg
|
||||
// buffers, and outputs written to result buffers. Each Run call may also use
|
||||
// a set of temporary buffers for the computation.
|
||||
//
|
||||
// By default each instance of this class manages its own arg, result and temp
|
||||
// buffers. The AllocMode constructor parameter may be used to modify the
|
||||
// buffer allocation strategy.
|
||||
//
|
||||
// Under the default allocation strategy, this class is thread-compatible:
|
||||
// o Calls to non-const methods require exclusive access to the object.
|
||||
// o Concurrent calls to const methods are OK, if those calls are made while
|
||||
// it is guaranteed that no thread may call a non-const method.
|
||||
//
|
||||
// The logical function signature is:
|
||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6]
|
||||
//
|
||||
// Memory stats:
|
||||
// arg bytes total: 104
|
||||
// arg bytes aligned: 128
|
||||
// temp bytes total: 126
|
||||
// temp bytes aligned: 224
|
||||
class MyClass {
|
||||
public:
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = 3;
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const intptr_t* ArgSizes() {
|
||||
static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1};
|
||||
return kArgSizes;
|
||||
}
|
||||
|
||||
// AllocMode controls the buffer allocation mode.
|
||||
enum class AllocMode {
|
||||
// Allocate all buffers - args, results and temps.
|
||||
ARGS_RESULTS_AND_TEMPS,
|
||||
|
||||
// Only allocate result and temp buffers.
|
||||
// Use set_argN_data to set argument buffers before Run is called.
|
||||
RESULTS_AND_TEMPS_ONLY,
|
||||
};
|
||||
|
||||
MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
|
||||
if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
|
||||
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
|
||||
ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
|
||||
}
|
||||
args_[kNumArgs-1] = &context_;
|
||||
alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
|
||||
TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
|
||||
}
|
||||
|
||||
~MyClass() {
|
||||
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
|
||||
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
|
||||
}
|
||||
|
||||
// Sets the thread pool to use during the Run call.
|
||||
MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
|
||||
run_options_.set_intra_op_thread_pool(pool);
|
||||
context_.thread_pool = pool;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Runs the computation, with inputs read from arg buffers, and outputs
|
||||
// written to result buffers. Returns true on success and false on failure.
|
||||
bool Run() {
|
||||
entry_point(temps_[kResultIndex], &run_options_, args_, temps_);
|
||||
return !context_.error;
|
||||
}
|
||||
|
||||
// Returns the error message from the previous failed Run call.
|
||||
tensorflow::string error_msg() const { return context_.error_msg; }
|
||||
|
||||
// Arg methods for managing input buffers. Buffers are in row-major order.
|
||||
// There is a set of methods for each positional argument, with the following
|
||||
// general form:
|
||||
//
|
||||
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an effect. Must be
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
// to set the argument buffers.
|
||||
//
|
||||
// T* argN_data()
|
||||
// Returns the buffer of type T for positional argument N.
|
||||
//
|
||||
// T& argN(...dim indices...)
|
||||
// Returns a reference to the value of type T for positional argument N,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// void** args()
|
||||
// Returns an array of argument buffers, where args()[N] is the buffer for
|
||||
// positional argument N.
|
||||
|
||||
void** args() { return args_; }
|
||||
const void *const *args() const { return args_; }
|
||||
|
||||
void set_arg0_data(void* data) {
|
||||
args_[0] = data;
|
||||
}
|
||||
float* arg0_data() {
|
||||
return static_cast<float*>(args_[0]);
|
||||
}
|
||||
float& arg0(size_t dim0, size_t dim1) {
|
||||
return (*static_cast<float(*)[1][2]>(
|
||||
args_[0]))[dim0][dim1];
|
||||
}
|
||||
const float* arg0_data() const {
|
||||
return static_cast<const float*>(args_[0]);
|
||||
}
|
||||
const float& arg0(size_t dim0, size_t dim1) const {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
args_[0]))[dim0][dim1];
|
||||
}
|
||||
|
||||
void set_arg_myfeed_data(void* data) {
|
||||
args_[0] = data;
|
||||
}
|
||||
float* arg_myfeed_data() {
|
||||
return static_cast<float*>(args_[0]);
|
||||
}
|
||||
float& arg_myfeed(size_t dim0, size_t dim1) {
|
||||
return (*static_cast<float(*)[1][2]>(
|
||||
args_[0]))[dim0][dim1];
|
||||
}
|
||||
const float* arg_myfeed_data() const {
|
||||
return static_cast<const float*>(args_[0]);
|
||||
}
|
||||
const float& arg_myfeed(size_t dim0, size_t dim1) const {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
args_[0]))[dim0][dim1];
|
||||
}
|
||||
|
||||
void set_arg1_data(void* data) {
|
||||
args_[1] = data;
|
||||
}
|
||||
tensorflow::int64* arg1_data() {
|
||||
return static_cast<tensorflow::int64*>(args_[1]);
|
||||
}
|
||||
tensorflow::int64& arg1(size_t dim0, size_t dim1) {
|
||||
return (*static_cast<tensorflow::int64(*)[3][4]>(
|
||||
args_[1]))[dim0][dim1];
|
||||
}
|
||||
const tensorflow::int64* arg1_data() const {
|
||||
return static_cast<const tensorflow::int64*>(args_[1]);
|
||||
}
|
||||
const tensorflow::int64& arg1(size_t dim0, size_t dim1) const {
|
||||
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
||||
args_[1]))[dim0][dim1];
|
||||
}
|
||||
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
// Must only be called after a successful Run call. There is a set of methods
|
||||
// for each positional result, with the following general form:
|
||||
//
|
||||
// T* resultN_data()
|
||||
// Returns the buffer of type T for positional result N.
|
||||
//
|
||||
// T& resultN(...dim indices...)
|
||||
// Returns a reference to the value of type T for positional result N,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// void** results()
|
||||
// Returns an array of result buffers, where results()[N] is the buffer for
|
||||
// positional result N.
|
||||
//
|
||||
// Unlike the arg methods, there is no set_resultN_data method. The result
|
||||
// buffers are managed internally, and may change after each call to Run.
|
||||
|
||||
void** results() { return temps_ + kResultIndex; }
|
||||
const void *const *results() const { return temps_ + kResultIndex; }
|
||||
|
||||
tensorflow::uint32* result0_data() {
|
||||
return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
|
||||
}
|
||||
tensorflow::uint32& result0(size_t dim0, size_t dim1) {
|
||||
return (*static_cast<tensorflow::uint32(*)[5][6]>(
|
||||
temps_[kResultIndex]))[dim0][dim1];
|
||||
}
|
||||
const tensorflow::uint32* result0_data() const {
|
||||
return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
|
||||
}
|
||||
const tensorflow::uint32& result0(size_t dim0, size_t dim1) const {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
temps_[kResultIndex]))[dim0][dim1];
|
||||
}
|
||||
|
||||
tensorflow::uint32* result_myfetch_data() {
|
||||
return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
|
||||
}
|
||||
tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) {
|
||||
return (*static_cast<tensorflow::uint32(*)[5][6]>(
|
||||
temps_[kResultIndex]))[dim0][dim1];
|
||||
}
|
||||
const tensorflow::uint32* result_myfetch_data() const {
|
||||
return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
|
||||
}
|
||||
const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
temps_[kResultIndex]))[dim0][dim1];
|
||||
}
|
||||
|
||||
private:
|
||||
// Number of result and temporary buffers for the compiled computation.
|
||||
static constexpr size_t kNumTemps = 6;
|
||||
// The 0-based index of the result in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = 5;
|
||||
|
||||
// Byte size of each result / temporary buffer. There are kNumTemps entries.
|
||||
static const intptr_t* TempSizes() {
|
||||
static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
|
||||
return kTempSizes;
|
||||
}
|
||||
|
||||
void* args_[kNumArgs];
|
||||
void* temps_[kNumTemps];
|
||||
void* alloc_args_ = nullptr;
|
||||
void* alloc_temps_ = nullptr;
|
||||
xla::ExecutableRunOptions run_options_;
|
||||
tensorflow::XlaLocalRuntimeContext context_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MyClass);
|
||||
};
|
||||
|
||||
} // end namespace bar
|
||||
} // end namespace foo
|
||||
|
||||
#endif // TFCOMPILE_GENERATED_entry_point_H_
|
||||
|
||||
// clang-format on
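For orientation, a minimal driver for the generated class above might look as follows. This sketch is not part of the commit: it only uses methods declared in the golden header, the Eigen thread-pool setup mirrors the test.cc template added later in this change, and the function name and argument values are invented for the example.

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// ...plus the generated header shown above.

int RunGeneratedComputation() {
  Eigen::ThreadPool pool(2);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  foo::bar::MyClass computation;         // allocates arg, result and temp buffers
  computation.set_thread_pool(&device);  // optional intra-op thread pool

  // Fill the f32[1,2] and s64[3,4] arguments in row-major order. The third
  // (opaque) argument is the runtime context, which the class sets up itself.
  computation.arg0(0, 0) = 1.0f;
  computation.arg0(0, 1) = 2.0f;
  for (size_t i = 0; i < 3; ++i) {
    for (size_t j = 0; j < 4; ++j) {
      computation.arg1(i, j) = static_cast<tensorflow::int64>(i * 4 + j);
    }
  }

  if (!computation.Run()) {
    return -1;  // computation.error_msg() describes the failure.
  }
  // Read one element of the u32[5,6] result.
  return static_cast<int>(computation.result0(0, 0));
}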
|
416
tensorflow/compiler/aot/compile.cc
Normal file
@ -0,0 +1,416 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/compile.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
const char* const kArgOp = "_Arg";
|
||||
const char* const kRetvalOp = "_Retval";
|
||||
const char* const kFeedIdAttr = "_feed_id";
|
||||
const char* const kFetchIdAttr = "_fetch_id";
|
||||
const char* const kShapeAttr = "_shape";
|
||||
const char* const kDebugNameAttr = "_debug_name";
|
||||
|
||||
namespace {
|
||||
|
||||
Status DumpGraph(const MainFlags& flags, const string& name,
|
||||
const Graph& graph) {
|
||||
if (flags.debug_dir.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
GraphDef graph_def;
|
||||
graph.ToGraphDef(&graph_def);
|
||||
string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
|
||||
return WriteTextProto(Env::Default(), file, graph_def);
|
||||
}
|
||||
|
||||
string TensorIdToString(const TensorId& id) {
|
||||
return strings::StrCat(id.node_name(), ":", id.output_index());
|
||||
}
|
||||
|
||||
typedef std::unordered_map<string, Node*> NodeMap;
|
||||
|
||||
// Each feed id identifies the positional output of some node, which may consist
|
||||
// of multiple edges. For each feed node, replaces all matching edges so that
|
||||
// they point from a new _Arg node instead.
|
||||
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<Feed>& feeds) {
|
||||
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
||||
const Feed& feed = feeds[arg_index];
|
||||
const TensorId& id = feed.id();
|
||||
auto it = node_map.find(id.node_name());
|
||||
if (it == node_map.end()) {
|
||||
return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
|
||||
}
|
||||
const Node* feed_node = it->second;
|
||||
if (id.output_index() >= feed_node->num_outputs()) {
|
||||
return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
|
||||
", output index should be < ",
|
||||
feed_node->num_outputs());
|
||||
}
|
||||
// TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
|
||||
// if we can determine it. That way the graph will be initialized with
|
||||
// whatever shapes we can infer, while the user can still explicitly specify
|
||||
// or override them.
|
||||
Node* arg_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
|
||||
.Attr("T", BaseType(feed_node->output_type(id.output_index())))
|
||||
.Attr("index", arg_index)
|
||||
.Attr(kFeedIdAttr, TensorIdToString(id))
|
||||
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
||||
.Attr(kDebugNameAttr, feed.name())
|
||||
.Finalize(graph, &arg_node));
|
||||
// Collects out-edges from the feed node that have a matching edge index;
|
||||
// these will be replaced with edges from the arg node instead. Also
|
||||
// replaces all control edges from Placeholder feed nodes; similar code
|
||||
// exists in subgraph::RewriteGraphForExecution.
|
||||
// TODO(toddw): Why only replace control edges from Placeholder?
|
||||
//
|
||||
// We must collect the edges first and process them in a second pass, since
|
||||
// removing the edge from the graph invalidates feed_node->out_edges.
|
||||
std::vector<const Edge*> feed_edges;
|
||||
for (const Edge* edge : feed_node->out_edges()) {
|
||||
if (edge->src_output() == id.output_index() ||
|
||||
(edge->src_output() == Graph::kControlSlot &&
|
||||
feed_node->type_string() == "Placeholder")) {
|
||||
feed_edges.push_back(edge);
|
||||
}
|
||||
}
|
||||
for (const Edge* edge : feed_edges) {
|
||||
if (edge->src_output() == id.output_index()) {
|
||||
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
||||
} else {
|
||||
CHECK_EQ(edge->src_output(), Graph::kControlSlot);
|
||||
graph->AddControlEdge(arg_node, edge->dst());
|
||||
}
|
||||
graph->RemoveEdge(edge);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Each fetch id identifies the positional output of some node. For each fetch
|
||||
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
|
||||
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<Fetch>& fetches,
|
||||
std::unordered_set<const Node*>* retval_nodes) {
|
||||
for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
|
||||
const TensorId& id = fetches[ret_index].id();
|
||||
auto it = node_map.find(id.node_name());
|
||||
if (it == node_map.end()) {
|
||||
return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
|
||||
}
|
||||
Node* fetch_node = it->second;
|
||||
if (id.output_index() >= fetch_node->num_outputs()) {
|
||||
return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
|
||||
", output index should be < ",
|
||||
fetch_node->num_outputs());
|
||||
}
|
||||
// Connects fetch_node -> retval_node.
|
||||
Node* retval_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
|
||||
.Input(fetch_node, id.output_index())
|
||||
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
|
||||
.Attr("index", ret_index)
|
||||
.Attr(kFetchIdAttr, TensorIdToString(id))
|
||||
.Finalize(graph, &retval_node));
|
||||
retval_nodes->insert(retval_node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RewriteAndPruneGraph identifies input and output edges (named by the feed and
|
||||
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
|
||||
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
|
||||
// execution to know the input and output args for the generated function.
|
||||
Status RewriteAndPruneGraph(Graph* graph, const Config& config,
|
||||
const MainFlags& flags) {
|
||||
NodeMap node_map;
|
||||
for (Node* n : graph->nodes()) {
|
||||
node_map[n->name()] = n;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
|
||||
std::unordered_set<const Node*> retval_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
|
||||
PruneForReverseReachability(graph, retval_nodes);
|
||||
FixupSourceAndSinkEdges(graph);
|
||||
TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
|
||||
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
||||
std::set<string> missing_feeds, missing_fetches;
|
||||
for (const Feed& feed : config.feed()) {
|
||||
missing_feeds.insert(TensorIdToString(feed.id()));
|
||||
}
|
||||
for (const Fetch& fetch : config.fetch()) {
|
||||
missing_fetches.insert(TensorIdToString(fetch.id()));
|
||||
}
|
||||
for (const Node* n : graph->nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
string feed_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
|
||||
if (missing_feeds.erase(feed_id) == 0) {
|
||||
return errors::Aborted(kArgOp, " node found with unknown feed id: ",
|
||||
feed_id);
|
||||
}
|
||||
} else if (n->type_string() == kRetvalOp) {
|
||||
string fetch_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
|
||||
if (missing_fetches.erase(fetch_id) == 0) {
|
||||
return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ",
|
||||
fetch_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!missing_feeds.empty() || !missing_fetches.empty()) {
|
||||
return errors::Aborted("Post graph-pruning", ", missing feeds: ",
|
||||
str_util::Join(missing_feeds, ", "),
|
||||
", missing fetches: ",
|
||||
str_util::Join(missing_fetches, ", "));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// CollectArgNodes collects _Arg nodes from the graph, and performs basic
|
||||
// sanity-checking to ensure the index and type attributes of each node are
|
||||
// initialized correctly.
|
||||
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
|
||||
std::map<int, Node*> indexed_arg_nodes;
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
auto insert_result = indexed_arg_nodes.insert({index, n});
|
||||
if (!insert_result.second) {
|
||||
const Node* dup = insert_result.first->second;
|
||||
return errors::InvalidArgument(
|
||||
"Multiple ", kArgOp, " nodes with index ", index, ", ",
|
||||
n->DebugString(), " and ", dup->DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_nodes->clear();
|
||||
for (const auto& index_node : indexed_arg_nodes) {
|
||||
if (index_node.first != arg_nodes->size()) {
|
||||
return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
|
||||
arg_nodes->size(), ", but got index ",
|
||||
index_node.first);
|
||||
}
|
||||
arg_nodes->push_back(index_node.second);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Fills in xla_args from the corresponding _Arg nodes in the graph.
|
||||
Status CreateXlaArgs(const Graph& graph,
|
||||
std::vector<XlaCompiler::Argument>* xla_args) {
|
||||
std::vector<Node*> arg_nodes;
|
||||
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
|
||||
for (const Node* node : arg_nodes) {
|
||||
XlaCompiler::Argument arg;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
|
||||
xla_args->push_back(arg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts the TensorFlow graph into an XLA computation, by executing the
|
||||
// graph symbolically, with each op building up the XLA HLO.
|
||||
Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
|
||||
const FunctionLibraryDefinition* flib_def,
|
||||
xla::Computation* computation, bool* has_context_arg) {
|
||||
// Create a device and context to convert the graph into an XLA computation.
|
||||
XlaOpRegistry::RegisterJitKernels();
|
||||
// Populate the context with args from the graph.
|
||||
for (Node* node : graph->nodes()) {
|
||||
node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
|
||||
}
|
||||
std::vector<XlaCompiler::Argument> xla_args;
|
||||
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
|
||||
|
||||
// Compile the graph into an XLA computation.
|
||||
XlaCompiler::Options compiler_options;
|
||||
compiler_options.client = client;
|
||||
compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
compiler_options.allow_cpu_custom_calls = true;
|
||||
XlaCompiler compiler(compiler_options);
|
||||
|
||||
std::unique_ptr<FunctionLibraryRuntime> flib_run(NewFunctionLibraryRuntime(
|
||||
compiler.device_mgr(), Env::Default(), compiler.device(),
|
||||
graph->versions().producer(), flib_def, OptimizerOptions()));
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph),
|
||||
flib_run.get(), xla_args,
|
||||
false /* use_tuple_arg */, &result));
|
||||
*has_context_arg = result.requires_runtime_context;
|
||||
*computation = std::move(result.computation);
|
||||
|
||||
int num_const_results = 0;
|
||||
for (int i = 0; i < result.outputs.size(); ++i) {
|
||||
// Ending up with const results (i.e. output args) is an error, since it
|
||||
// means that one or more fetches that the user specified will be dropped
|
||||
// from the generated function. It's most likely a configuration error,
|
||||
// since the user shouldn't be asking for output args that end up as consts.
|
||||
//
|
||||
// TODO(toddw): Provide a way for the user to access const output args,
|
||||
// e.g. perhaps hard-coded into the header, or somehow copied into the
|
||||
// output buffers.
|
||||
if (result.outputs[i].is_constant) {
|
||||
++num_const_results;
|
||||
LOG(ERROR) << "ConstRetVal index:" << i
|
||||
<< " value:" << result.outputs[i].constant_value.DebugString();
|
||||
}
|
||||
}
|
||||
if (num_const_results > 0) {
|
||||
return errors::Unimplemented(
|
||||
"Conversion from TensorFlow graph to XLA resulted in ",
|
||||
num_const_results,
|
||||
" constant results. The configuration of "
|
||||
"the output args (i.e. fetch ids) is probably wrong.");
|
||||
}
|
||||
if (computation->IsNull()) {
|
||||
return errors::Aborted(
|
||||
"Conversion from TensorFlow graph to XLA resulted in an empty "
|
||||
"computation.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Compiles the XLA computation into executable code.
|
||||
Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
|
||||
const xla::cpu::CpuAotCompilationOptions& aot_opts,
|
||||
CompileResult* compile_result) {
|
||||
// Retrieves arg and result layouts from the computation.
|
||||
// TODO(toddw): Should we let the user choose the major/minor ordering?
|
||||
xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
|
||||
client->GetComputationShape(computation);
|
||||
if (!pshape_or.ok()) {
|
||||
return errors::Unknown("Couldn't get XLA program shape: ",
|
||||
pshape_or.status().error_message());
|
||||
}
|
||||
compile_result->program_shape = *pshape_or.ValueOrDie();
|
||||
xla::ProgramShape* pshape = &compile_result->program_shape;
|
||||
std::vector<const xla::Shape*> arg_layouts;
|
||||
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
||||
arg_layouts.push_back(pshape->mutable_parameters(i));
|
||||
}
|
||||
xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> aot_or =
|
||||
client->CompileAheadOfTime(computation, arg_layouts, pshape->result(),
|
||||
aot_opts);
|
||||
if (!aot_or.ok()) {
|
||||
return errors::Unknown("XLA compilation failed: ",
|
||||
aot_or.status().error_message());
|
||||
}
|
||||
compile_result->aot =
|
||||
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
|
||||
aot_or.ConsumeValueOrDie());
|
||||
compile_result->entry_point = aot_opts.entry_point_name();
|
||||
compile_result->pointer_size =
|
||||
xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status InitGraph(const GraphDef& graph_def, const Config& config,
|
||||
const MainFlags& flags, const FunctionLibraryDefinition* flib,
|
||||
std::unique_ptr<Graph>* graph) {
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
std::unique_ptr<Graph> g(new Graph(flib));
|
||||
GraphDef copy_def(graph_def);
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(),
|
||||
0 /*node_offset*/));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
|
||||
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
|
||||
*graph = std::move(g);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
|
||||
const FunctionLibraryDefinition* flib,
|
||||
CompileResult* compile_result) {
|
||||
// Converts the graph into an XLA computation, and compiles the
|
||||
// computation.
|
||||
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
|
||||
namespace gpu = perftools::gputools;
|
||||
gpu::Platform* cpu_platform =
|
||||
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
|
||||
xla::LocalClient* client =
|
||||
xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
|
||||
xla::Computation computation;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib,
|
||||
&computation,
|
||||
&compile_result->has_context_arg));
|
||||
if (!flags.debug_dir.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
|
||||
computation.Snapshot());
|
||||
string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb");
|
||||
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module));
|
||||
}
|
||||
xla::cpu::CpuAotCompilationOptions aot_opts(
|
||||
flags.target_triple, flags.target_cpu, flags.target_features,
|
||||
flags.entry_point,
|
||||
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
|
||||
return CompileXla(client, computation, aot_opts, compile_result);
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
92
tensorflow/compiler/aot/compile.h
Normal file
@ -0,0 +1,92 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_AOT_COMPILE_H_
#define TENSORFLOW_COMPILER_AOT_COMPILE_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace tfcompile {

// Constants for op types and attribute names.
extern const char* const kArgOp;
extern const char* const kRetvalOp;
extern const char* const kFeedIdAttr;
extern const char* const kFetchIdAttr;
extern const char* const kShapeAttr;
extern const char* const kDebugNameAttr;

// InitGraph creates a graph based on the graph_def, that may then be compiled
// by CompileGraph.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node. All _Retval nodes also point to a special CompileExpressions node,
// used internally to finish the compilation.
//
// The rewritten graph is then pruned to only contain the portion necessary to
// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
// for debugging.
Status InitGraph(const GraphDef& graph_def, const Config& config,
                 const MainFlags& flags, const FunctionLibraryDefinition* flib,
                 std::unique_ptr<Graph>* graph);

// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
  // Contains object file and meta-info.
  std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
  xla::ProgramShape program_shape;  // Static shape of args and results.
  bool has_context_arg = false;     // Is last arg XlaLocalRuntimeContext?
  string entry_point;               // Name of generated function.
  int pointer_size = 0;             // Size of a pointer in bytes.
};

// CompileGraph compiles the graph into an object file containing a function
// that performs the graph operations.
//
// The graph must have _Arg and _Retval nodes representing the function inputs
// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
// value=TensorShape) representing the static shape of that input, and every
// _Retval node must point to a CompileExpressions node.
//
// Typically InitGraph is called to perform this initialization, followed by
// full specification of the shape attributes.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
                    const FunctionLibraryDefinition* flib,
                    CompileResult* result);

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_COMPILE_H_
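Taken together, the two entry points above are typically chained as sketched below. This is an illustrative fragment based only on the declarations in compile.h; the helper name is hypothetical, and Config parsing, flag population and error reporting are elided.

#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/core/lib/core/errors.h"  // for TF_RETURN_IF_ERROR

namespace tensorflow {
namespace tfcompile {

// Hypothetical helper chaining the two entry points declared in compile.h.
Status CompileGraphDefForAot(const GraphDef& graph_def, const Config& config,
                             const MainFlags& flags,
                             const FunctionLibraryDefinition* flib,
                             CompileResult* compile_result) {
  std::unique_ptr<Graph> graph;
  // Rewrite feeds/fetches into _Arg/_Retval nodes and prune the graph.
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, flib, &graph));
  // Lower to XLA and ahead-of-time compile to object code plus metadata.
  return CompileGraph(std::move(graph), flags, flib, compile_result);
}

}  // namespace tfcompile
}  // namespace tensorflow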
|
72
tensorflow/compiler/aot/flags.cc
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
const std::vector<Flag> tmp = {
|
||||
{"graph", &flags->graph,
|
||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"config", &flags->config,
|
||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||
"is expected to be in the human-readable proto text format, otherwise "
|
||||
"it is expected to be in the proto binary format."},
|
||||
{"dump_fetch_nodes", &flags->dump_fetch_nodes,
|
||||
"If set, only flags related to fetches are processed, and the resulting "
|
||||
"fetch nodes will be dumped to stdout in a comma-separated list. "
|
||||
"Typically used to format arguments for other tools, e.g. "
|
||||
"freeze_graph."},
|
||||
{"debug_dir", &flags->debug_dir,
|
||||
"Specifies a directory to dump debugging information, including "
|
||||
"rewritten graphs and the XLA HLO module."},
|
||||
// Flags controlling the XLA ahead-of-time compilation, that correspond to
|
||||
// the fields of xla::cpu::CpuAotCompilationOptions.
|
||||
//
|
||||
// TODO(toddw): The following flags also need to be supported:
|
||||
// --xla_cpu_llvm_opt_level
|
||||
// --xla_cpu_llvm_cl_opts
|
||||
{"target_triple", &flags->target_triple,
|
||||
"Target platform, similar to the clang -target flag. The general "
|
||||
"format is <arch><sub>-<vendor>-<sys>-<abi>. "
|
||||
"http://clang.llvm.org/docs/CrossCompilation.html#target-triple."},
|
||||
{"target_cpu", &flags->target_cpu,
|
||||
"Target cpu, similar to the clang -mcpu flag. "
|
||||
"http://clang.llvm.org/docs/CrossCompilation.html#cpu-fpu-abi"},
|
||||
{"target_features", &flags->target_features,
|
||||
"Target features, e.g. +avx2, +neon, etc."},
|
||||
{"entry_point", &flags->entry_point,
|
||||
"Name of the generated function. If multiple generated object files "
|
||||
"will be linked into the same binary, each will need a unique entry "
|
||||
"point."},
|
||||
{"cpp_class", &flags->cpp_class,
|
||||
"Name of the generated C++ class, wrapping the generated function. The "
|
||||
"syntax of this flag is [[<optional_namespace>::],...]<class_name>. "
|
||||
"This mirrors the C++ syntax for referring to a class, where multiple "
|
||||
"namespaces may precede the class name, separated by double-colons. "
|
||||
"The class will be generated in the given namespace(s), or if no "
|
||||
"namespaces are given, within the global namespace."},
|
||||
{"out_object", &flags->out_object, "Output object file name."},
|
||||
{"out_header", &flags->out_header, "Output header file name."},
|
||||
};
|
||||
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
48
tensorflow/compiler/aot/flags.h
Normal file
@ -0,0 +1,48 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_AOT_FLAGS_H_
#define TENSORFLOW_COMPILER_AOT_FLAGS_H_

#include <string>
#include <vector>

#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace tfcompile {

// Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags {
  string graph;
  string config;
  bool dump_fetch_nodes = false;
  string debug_dir;
  string target_triple;
  string target_cpu;
  string target_features;
  string entry_point;
  string cpp_class;
  string out_object;
  string out_header;
};

// Appends to flag_list a tensorflow::Flag for each field in MainFlags.
void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags);

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_FLAGS_H_
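As a rough sketch of how these flags are consumed, a main() appends them to a flag list and parses argv before invoking the compiler. The snippet below is illustrative rather than the actual tfcompile main; it assumes the Flags::Parse helper declared in tensorflow/core/util/command_line_flags.h.

// Illustrative flag wiring (not the actual tfcompile main).
#include <vector>

#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char** argv) {
  tensorflow::tfcompile::MainFlags flags;
  std::vector<tensorflow::Flag> flag_list;
  tensorflow::tfcompile::AppendMainFlags(&flag_list, &flags);
  // Assumes the Flags::Parse helper from command_line_flags.h.
  if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
    return 1;
  }
  // flags.graph, flags.config, flags.cpp_class, ... are now populated.
  return 0;
}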
|
98
tensorflow/compiler/aot/runtime.cc
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/runtime.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace runtime {
|
||||
|
||||
namespace {
|
||||
|
||||
// Inline memory allocation routines here, because depending on '//base' brings
|
||||
// in libraries which use c++ streams, which adds considerable code size on
|
||||
// android.
|
||||
inline void* aligned_malloc(size_t size, int minimum_alignment) {
|
||||
#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
|
||||
return memalign(minimum_alignment, size);
|
||||
#else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN
|
||||
void* ptr = nullptr;
|
||||
// posix_memalign requires that the requested alignment be at least
|
||||
// sizeof(void*). In this case, fall back on malloc which should return memory
|
||||
// aligned to at least the size of a pointer.
|
||||
const int required_alignment = sizeof(void*);
|
||||
if (minimum_alignment < required_alignment) return malloc(size);
|
||||
if (posix_memalign(&ptr, minimum_alignment, size) != 0)
|
||||
return nullptr;
|
||||
else
|
||||
return ptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void aligned_free(void* aligned_memory) { free(aligned_memory); }
|
||||
|
||||
size_t align_to(size_t n, size_t align) {
|
||||
return (((n - 1) / align) + 1) * align;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
|
||||
size_t total = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (sizes[i] != -1) {
|
||||
total += align_to(sizes[i], kAlign);
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
|
||||
bool annotate_initialized) {
|
||||
const size_t total = aligned_buffer_bytes(sizes, n);
|
||||
void* contiguous = nullptr;
|
||||
if (total > 0) {
|
||||
contiguous = aligned_malloc(total, kAlign);
|
||||
if (annotate_initialized) {
|
||||
// Since the memory for temp buffers is written to by JITed code, msan has
|
||||
// no way of knowing the memory was initialized, so explicitly mark it.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(contiguous, total);
|
||||
}
|
||||
}
|
||||
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (sizes[i] == -1) {
|
||||
bufs[i] = nullptr;
|
||||
} else {
|
||||
bufs[i] = reinterpret_cast<void*>(pos);
|
||||
pos += align_to(sizes[i], kAlign);
|
||||
}
|
||||
}
|
||||
return contiguous;
|
||||
}
|
||||
|
||||
void FreeContiguous(void* contiguous) {
|
||||
if (contiguous != nullptr) {
|
||||
aligned_free(contiguous);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
58
tensorflow/compiler/aot/runtime.h
Normal file
@ -0,0 +1,58 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains utilities to make it easier to invoke functions generated
// by tfcompile. Usage of these utilities is optional.

#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_

#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tfcompile {
namespace runtime {

// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
static constexpr size_t kAlign = 32;

// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign
// byte boundaries.
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);

// MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
// `sizes`. If `annotate_initialized` is set, the allocated memory will be
// annotated as having been initialized - this is useful when allocating
// temporary buffers.
//
// A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns the
// head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
                              bool annotate_initialized);

// FreeContiguous frees the contiguous block of memory allocated by
// MallocContiguousBuffers.
void FreeContiguous(void* contiguous);

}  // namespace runtime
}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_RUNTIME_H_
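A small usage sketch of the allocation helpers declared above, mirroring what the generated constructor and destructor do. The buffer sizes are arbitrary; the -1 entry shows the nullptr convention, and the alignment arithmetic follows the kAlign = 32 constant.

// Illustrative use of the runtime allocation helpers (arbitrary sizes).
#include "tensorflow/compiler/aot/runtime.h"

void Example() {
  using namespace tensorflow::tfcompile::runtime;
  // Three buffers: 8 bytes, "no buffer" (-1 -> nullptr), and 100 bytes.
  static constexpr intptr_t sizes[3] = {8, -1, 100};
  void* bufs[3];
  void* block = MallocContiguousBuffers(sizes, 3, bufs,
                                        /*annotate_initialized=*/false);
  // bufs[0] and bufs[2] point into one contiguous, kAlign-aligned block;
  // bufs[1] is nullptr. aligned_buffer_bytes(sizes, 3) == 32 + 128 == 160.
  FreeContiguous(block);
}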
|
125
tensorflow/compiler/aot/runtime_test.cc
Normal file
@ -0,0 +1,125 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/runtime.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace runtime {
|
||||
namespace {
|
||||
|
||||
TEST(Runtime, AlignmentValue) {
|
||||
// We've chosen 32 byte alignment for the tfcompile runtime to mimic the
|
||||
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
|
||||
// The tfcompile runtime also has a requirement that comes from the xla
|
||||
// generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8
|
||||
// So any value that we choose must abide by that constraint as well.
|
||||
EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment);
|
||||
}
|
||||
|
||||
TEST(Runtime, AlignedBufferBytes) {
|
||||
EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0);
|
||||
|
||||
static constexpr intptr_t sizesA[1] = {-1};
|
||||
EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
|
||||
|
||||
static constexpr intptr_t sizesB[1] = {3};
|
||||
EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32);
|
||||
|
||||
static constexpr intptr_t sizesC[1] = {32};
|
||||
EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32);
|
||||
|
||||
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
|
||||
EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192);
|
||||
}
|
||||
|
||||
void* add_ptr(void* base, uintptr_t delta) {
|
||||
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(base) + delta);
|
||||
}
|
||||
|
||||
// To test MallocContiguousBuffers and FreeContiguous, we just check for
|
||||
// expected nullptrs, and write to each byte of allocated memory. We rely on
|
||||
// the leak checker to tell us if there's an inconsistency between malloc and
|
||||
// free. We also check the contiguous property.
|
||||
TEST(Runtime, MallocFreeContiguousBuffers) {
|
||||
// Test empty sizes.
|
||||
void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false);
|
||||
EXPECT_EQ(base, nullptr);
|
||||
FreeContiguous(base);
|
||||
|
||||
// Test non-empty sizes with 0 sum.
|
||||
static constexpr intptr_t sizesA[1] = {-1};
|
||||
void* bufA[1];
|
||||
base = MallocContiguousBuffers(sizesA, 1, bufA, false);
|
||||
EXPECT_EQ(base, nullptr);
|
||||
EXPECT_EQ(bufA[0], nullptr);
|
||||
FreeContiguous(base);
|
||||
|
||||
// Test non-empty sizes with non-0 sum.
|
||||
static constexpr intptr_t sizesB[1] = {3};
|
||||
void* bufB[1];
|
||||
base = MallocContiguousBuffers(sizesB, 1, bufB, false);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufB[0], add_ptr(base, 0));
|
||||
char* bufB0_bytes = static_cast<char*>(bufB[0]);
|
||||
bufB0_bytes[0] = 'A';
|
||||
bufB0_bytes[1] = 'B';
|
||||
bufB0_bytes[2] = 'C';
|
||||
FreeContiguous(base);
|
||||
|
||||
// Test non-empty sizes with non-0 sum, and annotate_initialized.
|
||||
static constexpr intptr_t sizesC[1] = {3};
|
||||
void* bufC[1];
|
||||
base = MallocContiguousBuffers(sizesC, 1, bufC, true);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufC[0], add_ptr(base, 0));
|
||||
char* bufC0_bytes = static_cast<char*>(bufC[0]);
|
||||
bufC0_bytes[0] = 'A';
|
||||
bufC0_bytes[1] = 'B';
|
||||
bufC0_bytes[2] = 'C';
|
||||
FreeContiguous(base);
|
||||
|
||||
// Test mixed sizes.
|
||||
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
|
||||
void* bufD[7];
|
||||
base = MallocContiguousBuffers(sizesD, 7, bufD, false);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufD[0], add_ptr(base, 0));
|
||||
EXPECT_EQ(bufD[1], nullptr);
|
||||
EXPECT_EQ(bufD[2], add_ptr(base, 32));
|
||||
EXPECT_EQ(bufD[3], nullptr);
|
||||
EXPECT_EQ(bufD[4], add_ptr(base, 64));
|
||||
EXPECT_EQ(bufD[5], add_ptr(base, 128));
|
||||
EXPECT_EQ(bufD[6], add_ptr(base, 160));
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
const intptr_t size = sizesD[i];
|
||||
if (size != -1) {
|
||||
char* bufD_bytes = static_cast<char*>(bufD[i]);
|
||||
for (size_t j = 0; j < size; ++j) {
|
||||
bufD_bytes[j] = 'A' + j;
|
||||
}
|
||||
}
|
||||
}
|
||||
FreeContiguous(base);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace runtime
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
94
tensorflow/compiler/aot/test.cc
Normal file
@ -0,0 +1,94 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Generated by the tf_library build rule. DO NOT EDIT!
|
||||
//
|
||||
// This file contains a test and benchmark for the function generated by
|
||||
// tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to
|
||||
// real values before this file can be compiled.
|
||||
//
|
||||
// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
|
||||
// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
|
||||
// TFCOMPILE_NAME : Name for tests and benchmarks.
|
||||
//
|
||||
// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
|
||||
// generates a cc_test rule for you.
|
||||
|
||||
// These macros must be defined before eigen files are included.
|
||||
#define EIGEN_USE_THREADS
|
||||
#define EIGEN_USE_CUSTOM_THREAD_POOL
|
||||
|
||||
// clang-format off
|
||||
#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
|
||||
// clang-format on
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/platform/cpu_info.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
// Macros that expand to tokens based on the entry point name.
|
||||
// clang-format off
|
||||
#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
|
||||
#define TEST_NAME {{TFCOMPILE_NAME}}Test // NOLINT(whitespace/braces)
|
||||
#define BM_NAME BM_{{TFCOMPILE_NAME}} // NOLINT(whitespace/braces)
|
||||
// clang-format on
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (sizes[i] != -1) {
|
||||
memset(bufs[i], 0, sizes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Trivial test that runs the generated function to ensure it doesn't crash.
|
||||
TEST(TEST_NAME, NoCrash) {
|
||||
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
|
||||
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
||||
|
||||
CPP_CLASS computation;
|
||||
computation.set_thread_pool(&device);
|
||||
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
|
||||
|
||||
EXPECT_TRUE(computation.Run());
|
||||
}
|
||||
|
||||
// Simple benchmark that repeatedly runs the generated function.
|
||||
void BM_NAME(int iters) {
|
||||
testing::StopTiming();
|
||||
|
||||
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
|
||||
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
||||
|
||||
CPP_CLASS computation;
|
||||
computation.set_thread_pool(&device);
|
||||
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
|
||||
|
||||
testing::StartTiming();
|
||||
while (--iters) {
|
||||
computation.Run();
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_NAME);
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
16
tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
Normal file
@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
  id { node_name: "x_const" }
  shape {
    dim { size: 1 }
  }
}
feed {
  id { node_name: "y_const" }
  shape {
    dim { size: 1 }
  }
}
fetch {
  id { node_name: "x_y_sum" }
}
63
tensorflow/compiler/aot/test_graph_tfadd.pbtxt
Normal file
@ -0,0 +1,63 @@
node {
  name : "x_const"
  op : "Const"
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
  attr {
    key : "dtype"
    value {
      type : DT_INT32
    }
  }
}
node {
  name : "y_const"
  op : "Const"
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 2
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
}
node {
  name : "x_y_sum"
  op : "Add"
  input : "x_const"
  input : "y_const"
  attr {
    key : "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 15
}
146
tensorflow/compiler/aot/tests/BUILD
Normal file
@ -0,0 +1,146 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
|
||||
test_suite(
|
||||
name = "all_tests",
|
||||
tags = ["manual"],
|
||||
tests = [
|
||||
":test_graph_tfadd_test",
|
||||
":test_graph_tfadd_with_ckpt_saver_test",
|
||||
":test_graph_tfadd_with_ckpt_test",
|
||||
":test_graph_tfgather_test",
|
||||
":test_graph_tfmatmul_test",
|
||||
":test_graph_tfmatmulandadd_test",
|
||||
":tfcompile_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "make_test_graphs",
|
||||
testonly = 1,
|
||||
srcs = ["make_test_graphs.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python", # TODO(b/34059704): remove when fixed
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "gen_test_graphs",
|
||||
testonly = 1,
|
||||
outs = [
|
||||
"test_graph_tfadd.pb",
|
||||
"test_graph_tfadd_with_ckpt.pb",
|
||||
"test_graph_tfadd_with_ckpt.ckpt",
|
||||
"test_graph_tfadd_with_ckpt_saver.pb",
|
||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||
"test_graph_tfadd_with_ckpt_saver.saver",
|
||||
"test_graph_tfgather.pb",
|
||||
"test_graph_tfmatmul.pb",
|
||||
"test_graph_tfmatmulandadd.pb",
|
||||
],
|
||||
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
|
||||
tags = ["manual"],
|
||||
tools = [":make_test_graphs"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfadd",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd.config.pbtxt",
|
||||
cpp_class = "AddComp",
|
||||
graph = "test_graph_tfadd.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfadd_with_ckpt",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd_with_ckpt.config.pbtxt",
|
||||
cpp_class = "AddWithCkptComp",
|
||||
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
|
||||
graph = "test_graph_tfadd_with_ckpt.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfadd_with_ckpt_saver",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd_with_ckpt.config.pbtxt",
|
||||
cpp_class = "AddWithCkptSaverComp",
|
||||
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
|
||||
graph = "test_graph_tfadd_with_ckpt_saver.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfgather",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfgather.config.pbtxt",
|
||||
cpp_class = "GatherComp",
|
||||
graph = "test_graph_tfgather.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfmatmul",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfmatmul.config.pbtxt",
|
||||
cpp_class = "foo::bar::MatMulComp",
|
||||
graph = "test_graph_tfmatmul.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfmatmulandadd",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfmatmulandadd.config.pbtxt",
|
||||
cpp_class = "MatMulAndAddComp",
|
||||
graph = "test_graph_tfmatmulandadd.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tfcompile_test",
|
||||
srcs = ["tfcompile_test.cc"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":test_graph_tfadd",
|
||||
":test_graph_tfadd_with_ckpt",
|
||||
":test_graph_tfadd_with_ckpt_saver",
|
||||
":test_graph_tfgather",
|
||||
":test_graph_tfmatmul",
|
||||
":test_graph_tfmatmulandadd",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
119
tensorflow/compiler/aot/tests/make_test_graphs.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Generate tensorflow graphs for testing tfcompile."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags as flags_lib
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
flags = flags_lib
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('out_dir', '',
|
||||
'Output directory for graphs, checkpoints and savers.')
|
||||
|
||||
|
||||
def tfadd():
|
||||
x = constant_op.constant([1], name='x_const')
|
||||
y = constant_op.constant([2], name='y_const')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
|
||||
def tfadd_with_ckpt():
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = variables.Variable(constant_op.constant([0]), name='y_saved')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
init_op = variables.initialize_all_variables()
|
||||
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(y.assign(y + 42))
|
||||
# Without the checkpoint, the variable won't be set to 42.
|
||||
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
|
||||
saver.save(sess, ckpt)
|
||||
|
||||
|
||||
def tfadd_with_ckpt_saver():
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = variables.Variable(constant_op.constant([0]), name='y_saved')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
init_op = variables.initialize_all_variables()
|
||||
saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1)
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(y.assign(y + 42))
|
||||
# Without the checkpoint, the variable won't be set to 42.
|
||||
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
|
||||
saver.save(sess, ckpt_file)
|
||||
# Without the SaverDef, the restore op won't be named correctly.
|
||||
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
|
||||
with open(saver_file, 'w') as f:
|
||||
f.write(saver.as_saver_def().SerializeToString())
|
||||
|
||||
|
||||
def tfgather():
|
||||
params = array_ops.placeholder(dtypes.float32, name='params')
|
||||
indices = array_ops.placeholder(dtypes.int32, name='indices')
|
||||
array_ops.gather(params, indices, name='gather_output')
|
||||
|
||||
|
||||
def tfmatmul():
|
||||
x = array_ops.placeholder(dtypes.float32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.float32, name='y_hold')
|
||||
math_ops.matmul(x, y, name='x_y_prod')
|
||||
|
||||
|
||||
def tfmatmulandadd():
|
||||
# This tests multiple outputs.
|
||||
x = array_ops.placeholder(dtypes.float32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.float32, name='y_hold')
|
||||
math_ops.matmul(x, y, name='x_y_prod')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
|
||||
def write_graph(build_graph):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
build_graph()
|
||||
filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
|
||||
with open(filename, 'w') as f:
|
||||
f.write(g.as_graph_def().SerializeToString())
|
||||
|
||||
|
||||
def main(_):
|
||||
write_graph(tfadd)
|
||||
write_graph(tfadd_with_ckpt)
|
||||
write_graph(tfadd_with_ckpt_saver)
|
||||
write_graph(tfgather)
|
||||
write_graph(tfmatmul)
|
||||
write_graph(tfmatmulandadd)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
16
tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
Normal file
@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
  id { node_name: "x_const" }
  shape {
    dim { size: 1 }
  }
}
feed {
  id { node_name: "y_const" }
  shape {
    dim { size: 1 }
  }
}
fetch {
  id { node_name: "x_y_sum" }
}
10
tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
Normal file
@ -0,0 +1,10 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_hold" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
16
tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
Normal file
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "params" }
|
||||
shape {
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "indices" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "gather_output" }
|
||||
}
|
18
tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
Normal file
@ -0,0 +1,18 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_hold" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 3 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_hold" }
|
||||
shape {
|
||||
dim { size: 3 }
|
||||
dim { size: 2 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_prod" }
|
||||
}
|
25
tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
Normal file
@ -0,0 +1,25 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_hold" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 2 }
|
||||
}
|
||||
name: "x"
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_hold" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 2 }
|
||||
}
|
||||
name: "y"
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_prod" }
|
||||
name: "x_y_prod"
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
name: "x_y_sum"
|
||||
}
|
381
tensorflow/compiler/aot/tests/tfcompile_test.cc
Normal file
@ -0,0 +1,381 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#define EIGEN_USE_CUSTOM_THREAD_POOL
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
TEST(TFCompileTest, Add) {
|
||||
AddComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg1_data(), add.args()[1]);
|
||||
|
||||
add.arg0() = 1;
|
||||
add.arg1() = 2;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 3);
|
||||
EXPECT_EQ(add.result0_data()[0], 3);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
add.arg0_data()[0] = 123;
|
||||
add.arg1_data()[0] = 456;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 579);
|
||||
EXPECT_EQ(add.result0_data()[0], 579);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
const AddComp& add_const = add;
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 123);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 123);
|
||||
EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add_const.arg1(), 456);
|
||||
EXPECT_EQ(add_const.arg1_data()[0], 456);
|
||||
EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
|
||||
EXPECT_EQ(add_const.result0(), 579);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 579);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
}
|
||||
|
||||
// Run tests that use set_argN_data separately, to avoid accidentally re-using
|
||||
// non-existent buffers.
|
||||
TEST(TFCompileTest, Add_SetArg) {
|
||||
AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
|
||||
|
||||
int32 arg_x = 10;
|
||||
int32 arg_y = 32;
|
||||
add.set_arg0_data(&arg_x);
|
||||
add.set_arg1_data(&arg_y);
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg1_data(), add.args()[1]);
|
||||
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 42);
|
||||
EXPECT_EQ(add.result0_data()[0], 42);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, AddWithCkpt) {
|
||||
AddWithCkptComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
|
||||
add.arg0() = 1;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 43);
|
||||
EXPECT_EQ(add.result0_data()[0], 43);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
add.arg0_data()[0] = 111;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 153);
|
||||
EXPECT_EQ(add.result0_data()[0], 153);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
const AddWithCkptComp& add_const = add;
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 111);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 111);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
|
||||
EXPECT_EQ(add_const.result0(), 153);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 153);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
AddWithCkptSaverComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
|
||||
add.arg0() = 1;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 43);
|
||||
EXPECT_EQ(add.result0_data()[0], 43);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
add.arg0_data()[0] = 111;
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
EXPECT_EQ(add.result0(), 153);
|
||||
EXPECT_EQ(add.result0_data()[0], 153);
|
||||
EXPECT_EQ(add.result0_data(), add.results()[0]);
|
||||
|
||||
const AddWithCkptSaverComp& add_const = add;
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 111);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 111);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
|
||||
EXPECT_EQ(add_const.result0(), 153);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 153);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, Gather) {
|
||||
GatherComp gather;
|
||||
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
|
||||
EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
|
||||
|
||||
// Successful gather.
|
||||
{
|
||||
const float params[4] = {1, 2, 3, 4};
|
||||
std::copy(params + 0, params + 4, gather.arg0_data());
|
||||
const int32 indices[2] = {1, 3};
|
||||
std::copy(indices + 0, indices + 2, gather.arg1_data());
|
||||
EXPECT_TRUE(gather.Run());
|
||||
EXPECT_EQ(gather.error_msg(), "");
|
||||
const float results[2] = {2, 4};
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(gather.result0(i), results[i]);
|
||||
EXPECT_EQ(gather.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(gather.result0_data(), gather.results()[0]);
|
||||
|
||||
const GatherComp& gather_const = gather;
|
||||
EXPECT_EQ(gather_const.error_msg(), "");
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(gather_const.arg0(i), params[i]);
|
||||
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
|
||||
}
|
||||
EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(gather_const.arg1(i), indices[i]);
|
||||
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
|
||||
}
|
||||
EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(gather_const.result0(i), results[i]);
|
||||
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
|
||||
}
|
||||
|
||||
// Bad indices returns an error.
|
||||
{
|
||||
const float params[4] = {1, 2, 3, 4};
|
||||
std::copy(params + 0, params + 4, gather.arg0_data());
|
||||
const int32 indices[2] = {1, 4};
|
||||
std::copy(indices + 0, indices + 2, gather.arg1_data());
|
||||
EXPECT_FALSE(gather.Run());
|
||||
EXPECT_EQ(gather.error_msg(), "Invalid index for gather");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, MatMul2) {
|
||||
Eigen::ThreadPool tp(2);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
foo::bar::MatMulComp matmul;
|
||||
matmul.set_thread_pool(&device);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
|
||||
|
||||
// Test using the argN() methods.
|
||||
{
|
||||
matmul.arg0(0, 0) = 1;
|
||||
matmul.arg0(0, 1) = 2;
|
||||
matmul.arg0(0, 2) = 3;
|
||||
matmul.arg0(1, 0) = 4;
|
||||
matmul.arg0(1, 1) = 5;
|
||||
matmul.arg0(1, 2) = 6;
|
||||
|
||||
matmul.arg1(0, 0) = 7;
|
||||
matmul.arg1(0, 1) = 8;
|
||||
matmul.arg1(1, 0) = 9;
|
||||
matmul.arg1(1, 1) = 10;
|
||||
matmul.arg1(2, 0) = 11;
|
||||
matmul.arg1(2, 1) = 12;
|
||||
|
||||
EXPECT_TRUE(matmul.Run());
|
||||
EXPECT_EQ(matmul.error_msg(), "");
|
||||
const float results[4] = {58, 64, 139, 154};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
|
||||
EXPECT_EQ(matmul.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
|
||||
}
|
||||
|
||||
// Test using the argN_data() methods.
|
||||
{
|
||||
const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
|
||||
std::copy(args + 0, args + 6, matmul.arg0_data());
|
||||
std::copy(args + 6, args + 12, matmul.arg1_data());
|
||||
EXPECT_TRUE(matmul.Run());
|
||||
EXPECT_EQ(matmul.error_msg(), "");
|
||||
const float results[4] = {5800, 6400, 13900, 15400};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
|
||||
EXPECT_EQ(matmul.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
|
||||
|
||||
const foo::bar::MatMulComp& matmul_const = matmul;
|
||||
EXPECT_EQ(matmul_const.error_msg(), "");
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
|
||||
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
|
||||
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
|
||||
}
|
||||
EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
|
||||
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Run tests that use set_argN_data separately, to avoid accidentally re-using
|
||||
// non-existent buffers.
|
||||
TEST(TFCompileTest, MatMul2_SetArg) {
|
||||
Eigen::ThreadPool tp(2);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
foo::bar::MatMulComp matmul(
|
||||
foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
|
||||
matmul.set_thread_pool(&device);
|
||||
|
||||
// Test using the set_argN_data() methods.
|
||||
float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}};
|
||||
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
|
||||
matmul.set_arg0_data(&arg0);
|
||||
matmul.set_arg1_data(&arg1);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
|
||||
|
||||
EXPECT_TRUE(matmul.Run());
|
||||
EXPECT_EQ(matmul.error_msg(), "");
|
||||
const float results[4] = {58, 64, 139, 154};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
|
||||
EXPECT_EQ(matmul.result0_data()[i], results[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
MatMulAndAddComp muladd;
|
||||
muladd.set_thread_pool(&device);
|
||||
EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
|
||||
EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
|
||||
|
||||
// Test methods with positional args and results.
|
||||
{
|
||||
const float args[8] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
std::copy(args + 0, args + 4, muladd.arg0_data());
|
||||
std::copy(args + 4, args + 8, muladd.arg1_data());
|
||||
EXPECT_TRUE(muladd.Run());
|
||||
EXPECT_EQ(muladd.error_msg(), "");
|
||||
const float results0[4] = {19, 22, 43, 50};
|
||||
const float results1[4] = {6, 8, 10, 12};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd.result0(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd.result0_data()[i], results0[i]);
|
||||
EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]);
|
||||
EXPECT_EQ(muladd.result1_data()[i], results1[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd.result0_data(), muladd.results()[0]);
|
||||
EXPECT_EQ(muladd.result1_data(), muladd.results()[1]);
|
||||
|
||||
const MatMulAndAddComp& muladd_const = muladd;
|
||||
EXPECT_EQ(muladd_const.error_msg(), "");
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
|
||||
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
|
||||
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
|
||||
EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]);
|
||||
EXPECT_EQ(muladd_const.result1_data()[i], results1[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]);
|
||||
EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]);
|
||||
}
|
||||
|
||||
// Test methods with named args and results.
|
||||
{
|
||||
const float args[8] = {10, 20, 30, 40, 50, 60, 70, 80};
|
||||
std::copy(args + 0, args + 4, muladd.arg_x_data());
|
||||
std::copy(args + 4, args + 8, muladd.arg_y_data());
|
||||
EXPECT_TRUE(muladd.Run());
|
||||
EXPECT_EQ(muladd.error_msg(), "");
|
||||
const float results0[4] = {1900, 2200, 4300, 5000};
|
||||
const float results1[4] = {60, 80, 100, 120};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd.result_x_y_prod(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd.result_x_y_prod_data()[i], results0[i]);
|
||||
EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]);
|
||||
EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]);
|
||||
EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]);
|
||||
|
||||
// Test const methods.
|
||||
const MatMulAndAddComp& muladd_const = muladd;
|
||||
EXPECT_EQ(muladd_const.error_msg(), "");
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
|
||||
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
|
||||
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
|
||||
EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]);
|
||||
EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]);
|
||||
EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
285
tensorflow/compiler/aot/tfcompile.bzl
Normal file
@ -0,0 +1,285 @@
|
||||
# -*- Python -*-
|
||||
|
||||
"""Build macro that compiles a TensorFlow graph into a cc_library.
|
||||
|
||||
To use from your BUILD file, add the following line to load the macro:
|
||||
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
|
||||
Then call the macro like this:
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfmatmul",
|
||||
config = "test_graph_tfmatmul.config.pbtxt",
|
||||
cpp_class = "MatMulComp",
|
||||
graph = ":test_graph_tfmatmul.pb",
|
||||
)
|
||||
"""
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
|
||||
|
||||
def tf_library(name, graph, config,
|
||||
freeze_checkpoint=None, freeze_saver=None,
|
||||
cpp_class=None, gen_test=True, gen_benchmark=True,
|
||||
visibility=None, testonly=None,
|
||||
tfcompile_flags=None,
|
||||
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
|
||||
deps=None, tags=None):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
|
||||
Args:
|
||||
name: The name of the build rule.
|
||||
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
|
||||
is expected to be in the human-readable proto text format, otherwise it is
|
||||
expected to be in the proto binary format.
|
||||
config: File containing tensorflow.tfcompile.Config proto. If the file ends
|
||||
in '.pbtxt' it is expected to be in the human-readable proto text format,
|
||||
otherwise it is expected to be in the proto binary format.
|
||||
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
|
||||
convert variables into constants.
|
||||
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
|
||||
binary form, to convert variables into constants.
|
||||
cpp_class: The name of the generated C++ class, wrapping the generated
|
||||
function. The syntax of this flag is
|
||||
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
|
||||
for referring to a class, where multiple namespaces may precede the class
|
||||
name, separated by double-colons. The class will be generated in the
|
||||
given namespace(s), or if no namespaces are given, within the global
|
||||
namespace.
|
||||
gen_test: If True, also generate a cc_test rule that builds a simple
|
||||
test and benchmark.
|
||||
gen_benchmark: If True, also generate a binary with a simple benchmark.
|
||||
Unlike the output of gen_test, this benchmark can be run on android.
|
||||
visibility: Bazel build visibility.
|
||||
testonly: Bazel testonly attribute.
|
||||
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
|
||||
tfcompile_tool: The tfcompile binary. A non-default can be passed to
|
||||
use a tfcompile built with extra dependencies.
|
||||
deps: a list of extra deps to include on the build rules for
|
||||
the generated library.
|
||||
tags: tags to apply to subsidiary build rules.
|
||||
|
||||
The output header is called <name>.h.
|
||||
"""
|
||||
if not cpp_class:
|
||||
fail("cpp_class must be specified")
|
||||
|
||||
tfcompile_graph = graph
|
||||
if freeze_checkpoint or freeze_saver:
|
||||
if not freeze_checkpoint:
|
||||
fail("freeze_checkpoint must be specified when freeze_saver is specified")
|
||||
|
||||
freeze_name = "freeze_" + name
|
||||
freeze_file = freeze_name + ".pb"
|
||||
|
||||
# First run tfcompile to generate the list of out_nodes.
|
||||
out_nodes_file = "out_nodes_" + freeze_name
|
||||
native.genrule(
|
||||
name=("gen_" + out_nodes_file),
|
||||
srcs=[config],
|
||||
outs=[out_nodes_file],
|
||||
cmd=("$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
tools=[tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local=1,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Now run freeze_graph to convert variables into constants.
|
||||
freeze_args = (" --input_graph=$(location " + graph + ")" +
|
||||
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
|
||||
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
|
||||
" --output_graph=$(location " + freeze_file + ")" +
|
||||
" --output_node_names=$$(<$(location " + out_nodes_file +
|
||||
"))")
|
||||
freeze_saver_srcs = []
|
||||
if freeze_saver:
|
||||
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
|
||||
freeze_saver_srcs += [freeze_saver]
|
||||
native.genrule(
|
||||
name=freeze_name,
|
||||
srcs=[
|
||||
graph,
|
||||
freeze_checkpoint,
|
||||
out_nodes_file,
|
||||
] + freeze_saver_srcs,
|
||||
outs=[freeze_file],
|
||||
cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args),
|
||||
tools=["//tensorflow/python/tools:freeze_graph"],
|
||||
tags=tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
|
||||
# Rule that runs tfcompile to produce the header and object file.
|
||||
header_file = name + ".h"
|
||||
object_file = name + ".o"
|
||||
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
|
||||
native.genrule(
|
||||
name=("gen_" + name),
|
||||
srcs=[
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
outs=[
|
||||
header_file,
|
||||
object_file,
|
||||
],
|
||||
cmd=("$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_header=$(@D)/" + header_file +
|
||||
" --out_object=$(@D)/" + object_file +
|
||||
" " + (tfcompile_flags or "")),
|
||||
tools=[tfcompile_tool],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the local
|
||||
# machine.
|
||||
#
|
||||
# Note that setting the local=1 attribute on a *test target* causes the
|
||||
# test infrastructure to skip that test. However this is a genrule, not a
|
||||
# test target, and runs with --genrule_strategy=forced_forge, meaning the
|
||||
# local=1 attribute is ignored, and the genrule is still run.
|
||||
#
|
||||
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
|
||||
local=1,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# The cc_library rule packaging up the header and object file, and needed
|
||||
# kernel implementations.
|
||||
native.cc_library(
|
||||
name=name,
|
||||
srcs=[object_file],
|
||||
hdrs=[header_file],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
deps = [
|
||||
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
|
||||
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32",
|
||||
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64",
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework_lite",
|
||||
] + (deps or []),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Variables used for gen_test and gen_benchmark.
|
||||
no_ns_name = ""
|
||||
cpp_class_split = cpp_class.rsplit("::", maxsplit=2)
|
||||
if len(cpp_class_split) == 1:
|
||||
no_ns_name = cpp_class_split[0]
|
||||
else:
|
||||
no_ns_name = cpp_class_split[1]
|
||||
sed_replace = (
|
||||
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
|
||||
|
||||
if gen_test:
|
||||
test_name = name + "_test"
|
||||
test_file = test_name + ".cc"
|
||||
# Rule to rewrite test.cc to produce the test_file.
|
||||
native.genrule(
|
||||
name=("gen_" + test_name),
|
||||
testonly=1,
|
||||
srcs=[
|
||||
"//tensorflow/compiler/aot:test.cc",
|
||||
header_file,
|
||||
],
|
||||
outs=[test_file],
|
||||
cmd=("sed " + sed_replace +
|
||||
" $(location //tensorflow/compiler/aot:test.cc) " +
|
||||
"> $(OUTS)"),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# The cc_test rule for the generated code.
|
||||
native.cc_test(
|
||||
name=test_name,
|
||||
srcs=[test_file],
|
||||
deps=[
|
||||
":" + name,
|
||||
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/aot:tf_library_test_main",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
if gen_benchmark:
|
||||
benchmark_name = name + "_benchmark"
|
||||
benchmark_file = benchmark_name + ".cc"
|
||||
benchmark_main = ("//tensorflow/compiler/aot:" +
|
||||
"benchmark_main.template")
|
||||
|
||||
# Rule to rewrite benchmark.cc to produce the benchmark_file.
|
||||
native.genrule(
|
||||
name=("gen_" + benchmark_name),
|
||||
srcs=[
|
||||
benchmark_main,
|
||||
header_file,
|
||||
],
|
||||
testonly = testonly,
|
||||
outs=[benchmark_file],
|
||||
cmd=("sed " + sed_replace +
|
||||
" $(location " + benchmark_main + ") " +
|
||||
"> $(OUTS)"),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# The cc_benchmark rule for the generated code.
|
||||
#
|
||||
# Note: to get smaller size on android for comparison, compile with:
|
||||
# --copt=-fvisibility=hidden
|
||||
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
|
||||
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
|
||||
native.cc_binary(
|
||||
name=benchmark_name,
|
||||
srcs=[benchmark_file],
|
||||
testonly = testonly,
|
||||
copts = tf_copts(),
|
||||
linkopts = if_android(["-pie", "-s"]),
|
||||
deps=[
|
||||
":" + name,
|
||||
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
|
||||
"//tensorflow/compiler/aot:benchmark",
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
] + if_android([
|
||||
"//tensorflow/compiler/aot:benchmark_extra_android",
|
||||
]),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def target_llvm_triple():
|
||||
"""Returns the target LLVM triple to be used for compiling the target."""
|
||||
# TODO(toddw): Add target_triple for other targets. For details see:
|
||||
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
|
||||
return select({
|
||||
"//tensorflow:android_arm": "armv7-none-android",
|
||||
"//tensorflow:android_arm64": "aarch64-none-android",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
})
|
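
Editor's note: for orientation, below is a minimal, hypothetical C++ sketch of how application code might call into a class generated by the tf_library macro above. It reuses the names from the tests in this commit (the generated header tests/test_graph_tfmatmul.h and cpp_class foo::bar::MatMulComp) and the accessor pattern exercised by tfcompile_test.cc; the exact argN/resultN methods depend on the feeds and fetches in the config, so treat this as a sketch rather than the definitive API.

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"  // generated by tf_library

int main(int argc, char** argv) {
  // The generated computation can run its matmuls on an Eigen thread pool.
  Eigen::ThreadPool tp(2);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  foo::bar::MatMulComp matmul;
  matmul.set_thread_pool(&device);

  // Feeds become positional args; here arg0 is 2x3 and arg1 is 3x2, per the
  // test config earlier in this commit. Buffers are row-major.
  for (int i = 0; i < 6; ++i) {
    matmul.arg0_data()[i] = i + 1;  // 1..6
    matmul.arg1_data()[i] = i + 7;  // 7..12
  }

  if (!matmul.Run()) {
    return 1;  // matmul.error_msg() would describe the failure
  }
  // Fetches become positional results; result0 is the 2x2 product.
  float top_left = matmul.result0(0, 0);
  return top_left == 58 ? 0 : 1;  // 1*7 + 2*9 + 3*11 == 58
}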
43
tensorflow/compiler/aot/tfcompile.proto
Normal file
@ -0,0 +1,43 @@
syntax = "proto3";

package tensorflow.tfcompile;
option cc_enable_arenas = true;
option java_outer_classname = "CompileProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";

import "tensorflow/core/framework/tensor_shape.proto";

// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
// index of a particular node in the graph. If the output of the named node
// feeds into other node(s), this corresponds to one or more edges. Otherwise
// it doesn't correspond to any existing edges at all, e.g. for output nodes.
message TensorId {
  string node_name = 1;
  int64 output_index = 2;
};

// Feed represents a single feed tensor in the graph, which corresponds to an
// input argument for the generated function.
message Feed {
  TensorId id = 1;
  TensorShapeProto shape = 2;
  string name = 3;  // Optional name for generated code.
};

// Fetch represents a single fetch tensor in the graph, which corresponds to an
// output argument for the generated function.
message Fetch {
  TensorId id = 1;
  string name = 2;  // Optional name for generated code.
};

// Config represents configuration information for tfcompile.
message Config {
  // Each feed is a positional input argument for the generated function. The
  // order of each entry matches the order of each input argument.
  repeated Feed feed = 1;
  // Each fetch is a positional output argument for the generated function. The
  // order of each entry matches the order of each output argument.
  repeated Fetch fetch = 2;
};
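
Editor's note: as a quick illustration of the messages above, the sketch below builds a Config in C++ rather than in text form. It is a hypothetical example (the node names "x_hold" and "x_y_sum" are placeholders), but it mirrors how tfcompile_util_test.cc in this commit constructs configs.

#include "tensorflow/compiler/aot/tfcompile.pb.h"

tensorflow::tfcompile::Config MakeExampleConfig() {
  tensorflow::tfcompile::Config config;

  // One feed: a 1-element vector fed through the node "x_hold".
  tensorflow::tfcompile::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("x_hold");
  feed->mutable_id()->set_output_index(0);
  feed->mutable_shape()->add_dim()->set_size(1);
  feed->set_name("x");  // optional, used for generated method names

  // One fetch: the output of the node "x_y_sum".
  tensorflow::tfcompile::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("x_y_sum");
  fetch->set_name("sum");

  return config;
}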
142
tensorflow/compiler/aot/tfcompile_main.cc
Normal file
@ -0,0 +1,142 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/compile.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
const char kUsageHeader[] =
|
||||
"tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
|
||||
"resulting in an object file compiled for your target architecture, and a\n"
|
||||
"header file that gives access to the functionality in the object file.\n"
|
||||
"A typical invocation looks like this:\n"
|
||||
"\n"
|
||||
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
|
||||
"\n";
|
||||
|
||||
Status ReadProtoFile(const string& kind, const string& fname,
|
||||
protobuf::Message* proto) {
|
||||
if (StringPiece(fname).ends_with(".pbtxt")) {
|
||||
return ReadTextProto(Env::Default(), fname, proto);
|
||||
} else {
|
||||
return ReadBinaryProto(Env::Default(), fname, proto);
|
||||
}
|
||||
}
|
||||
|
||||
void ParseTensorId(const string& name, TensorId* id) {
|
||||
const std::pair<StringPiece, int> name_index = ParseTensorName(name);
|
||||
id->set_node_name(name_index.first.ToString());
|
||||
id->set_output_index(name_index.second);
|
||||
}
|
||||
|
||||
Status Main(const MainFlags& flags) {
|
||||
// Process config.
|
||||
Config config;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
if (flags.dump_fetch_nodes) {
|
||||
std::set<string> nodes;
|
||||
for (const Fetch& fetch : config.fetch()) {
|
||||
nodes.insert(fetch.id().node_name());
|
||||
}
|
||||
std::cout << str_util::Join(nodes, ",");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read and initialize the graph.
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
|
||||
std::unique_ptr<Graph> graph;
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
|
||||
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
|
||||
|
||||
CompileResult compile_result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(std::move(graph), flags, &flib, &compile_result));
|
||||
|
||||
// Write output files.
|
||||
Env* env = Env::Default();
|
||||
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
|
||||
StringPiece(obj.data(), obj.size())));
|
||||
HeaderOpts header_opts;
|
||||
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
|
||||
&header_opts.namespaces));
|
||||
string header;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GenerateHeader(header_opts, config, compile_result, &header));
|
||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace tfcompile
|
||||
} // end namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::tfcompile::MainFlags flags;
|
||||
flags.target_triple = "x86_64-pc-linux";
|
||||
flags.out_object = "out.o";
|
||||
flags.out_header = "out.h";
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
|
||||
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
||||
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
|
||||
|
||||
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
|
||||
usage += tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
QCHECK(parsed_flags_ok) << "\n" << usage;
|
||||
|
||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||
QCHECK(argc == 1 && !flags.config.empty() &&
|
||||
(flags.dump_fetch_nodes ||
|
||||
(!flags.graph.empty() && !flags.entry_point.empty())))
|
||||
<< "\n"
|
||||
<< usage;
|
||||
|
||||
TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
|
||||
return 0;
|
||||
}
|
119
tensorflow/compiler/aot/tfcompile_util.cc
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsAlpha(char c) {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
|
||||
}
|
||||
|
||||
bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
|
||||
|
||||
Status ValidateTensorId(const TensorId& id) {
|
||||
if (id.node_name().empty()) {
|
||||
return errors::InvalidArgument("TensorId node_name must be non-empty");
|
||||
}
|
||||
if (id.output_index() < 0) {
|
||||
return errors::InvalidArgument("TensorId output_index must be positive");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateFeedFetchName(const string& kind, const string& name,
|
||||
std::set<string>* names) {
|
||||
if (!name.empty()) {
|
||||
TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
|
||||
if (!names->insert(name).second) {
|
||||
return errors::InvalidArgument("duplicate ", kind, " name: ", name);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckFeedFetchNameConflicts(const string& kind,
|
||||
const std::set<string>& names) {
|
||||
// We don't allow the feeds or fetches to contain both "foo" and "foo_data",
|
||||
// since that will cause a collision in codegen symbols.
|
||||
for (const string& name : names) {
|
||||
const string name_data(name + "_data");
|
||||
if (names.find(name_data) != names.end()) {
|
||||
return errors::InvalidArgument("conflicting ", kind, " name: ", name,
|
||||
" and ", name_data);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
|
||||
if (ident.empty()) {
|
||||
return errors::InvalidArgument("empty identifier: ", msg);
|
||||
}
|
||||
// Require that the identifier starts with a nondigit, and is composed of
|
||||
// nondigits and digits, as specified in section [2.11 Identifiers] of the
|
||||
// C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
|
||||
// defined as [0-9].
|
||||
//
|
||||
// Technically the standard also allows for `universal-character-name`, with a
|
||||
// table of allowed unicode ranges, as well as `other implementation-defined
|
||||
// characters`. We disallow those here to give better error messages, at the
|
||||
// expense of being more restrictive than the standard.
|
||||
if (ident[0] != '_' && !IsAlpha(ident[0])) {
|
||||
return errors::InvalidArgument("illegal leading char: ", msg);
|
||||
}
|
||||
for (size_t pos = 1; pos < ident.size(); ++pos) {
|
||||
if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
|
||||
return errors::InvalidArgument("illegal char: ", msg);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateConfig(const Config& config) {
|
||||
std::set<string> names;
|
||||
for (const Feed& feed : config.feed()) {
|
||||
TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
|
||||
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
|
||||
TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
|
||||
names.clear();
|
||||
for (const Fetch& fetch : config.fetch()) {
|
||||
TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
|
||||
TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
|
||||
if (config.feed().empty() || config.fetch().empty()) {
|
||||
return errors::InvalidArgument("feeds and fetches must be specified");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
36
tensorflow/compiler/aot/tfcompile_util.h
Normal file
@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_

#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"

namespace tensorflow {
namespace tfcompile {

// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
Status ValidateCppIdent(StringPiece ident, StringPiece msg);

// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
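
Editor's note: a minimal sketch of how a caller might use the two helpers declared above. The wrapper function, its arguments, and the identifier being checked are illustrative only; the helpers themselves are the ones declared in tfcompile_util.h and defined in tfcompile_util.cc.

#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace tfcompile {

Status CheckBeforeCompile(const Config& config, StringPiece class_name) {
  // Reject class names that would not be legal C++ identifiers in codegen.
  TF_RETURN_IF_ERROR(ValidateCppIdent(class_name, "cpp class name"));
  // Reject configs with missing feeds/fetches, bad TensorIds, or name clashes.
  return ValidateConfig(config);
}

}  // namespace tfcompile
}  // namespace tensorflow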
185
tensorflow/compiler/aot/tfcompile_util_test.cc
Normal file
@ -0,0 +1,185 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
void ExpectErrorContains(Status status, StringPiece str) {
|
||||
EXPECT_NE(Status::OK(), status);
|
||||
EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
|
||||
<< "expected error: " << status.error_message() << " to contain: " << str;
|
||||
}
|
||||
|
||||
TEST(ValidateCppIdent, Simple) {
|
||||
TF_EXPECT_OK(ValidateCppIdent("a", ""));
|
||||
TF_EXPECT_OK(ValidateCppIdent("abc", ""));
|
||||
TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
|
||||
TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
|
||||
// Make sure we didn't skip a valid letter or digit
|
||||
string ident;
|
||||
for (char c = 'a'; c <= 'z'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
for (char c = 'A'; c <= 'Z'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
for (char c = '0'; c <= '9'; c++) {
|
||||
ident.append(1, c);
|
||||
}
|
||||
ident += "_";
|
||||
TF_EXPECT_OK(ValidateCppIdent(ident, ""));
|
||||
|
||||
ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
|
||||
ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
|
||||
ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
|
||||
ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
|
||||
ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
|
||||
ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
|
||||
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
|
||||
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, Good) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
feed->mutable_id()->set_output_index(123);
|
||||
feed->set_name("foo_debug");
|
||||
feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("bar");
|
||||
feed->mutable_id()->set_output_index(0);
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("baz");
|
||||
fetch->mutable_id()->set_output_index(456);
|
||||
fetch->set_name("baz_debug");
|
||||
fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("banana");
|
||||
fetch->mutable_id()->set_output_index(0);
|
||||
TF_EXPECT_OK(ValidateConfig(config));
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadEmpty) {
|
||||
Config config;
|
||||
ExpectErrorContains(ValidateConfig(config),
|
||||
"feeds and fetches must be specified");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadNoFeed) {
|
||||
Config config;
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("foo");
|
||||
ExpectErrorContains(ValidateConfig(config),
|
||||
"feeds and fetches must be specified");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadNoFetch) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
ExpectErrorContains(ValidateConfig(config),
|
||||
"feeds and fetches must be specified");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadFeedNodeName) {
|
||||
Config config;
|
||||
config.add_feed();
|
||||
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadFeedOutputIndex) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
feed->mutable_id()->set_output_index(-1);
|
||||
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadFetchNodeName) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
config.add_fetch();
|
||||
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, BadFetchOutputIndex) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("bar");
|
||||
fetch->mutable_id()->set_output_index(-1);
|
||||
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, DuplicateFeedName) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
feed->set_name("dup");
|
||||
feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("bar");
|
||||
feed->set_name("dup");
|
||||
ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, DuplicateFetchName) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("bar");
|
||||
fetch->set_name("dup");
|
||||
fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("baz");
|
||||
fetch->set_name("dup");
|
||||
ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, ConflictingFeedName) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
feed->set_name("conflict");
|
||||
feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("bar");
|
||||
feed->set_name("conflict_data");
|
||||
ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
|
||||
}
|
||||
|
||||
TEST(ValidateConfig, ConflictingFetchName) {
|
||||
Config config;
|
||||
Feed* feed = config.add_feed();
|
||||
feed->mutable_id()->set_node_name("foo");
|
||||
Fetch* fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("bar");
|
||||
fetch->set_name("conflict");
|
||||
fetch = config.add_fetch();
|
||||
fetch->mutable_id()->set_node_name("baz");
|
||||
fetch->set_name("conflict_data");
|
||||
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
282
tensorflow/compiler/jit/BUILD
Normal file
@ -0,0 +1,282 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
includes = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = [
|
||||
"//tensorflow/compiler/tf2xla:friends",
|
||||
],
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
|
||||
# Target that bundles up the XLA CPU and GPU JIT devices.
|
||||
cc_library(
|
||||
name = "jit",
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_cpu_jit",
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_local_launch_op",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_gpu_jit",
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_local_launch_op",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_cpu_device",
|
||||
srcs = ["xla_cpu_device.cc"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_device",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_gpu_device",
|
||||
srcs = ["xla_gpu_device.cc"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_device",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Internal targets below this point.
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
"defs.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"defs.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_device",
|
||||
srcs = [
|
||||
"xla_device.cc",
|
||||
"xla_device_context.cc",
|
||||
"xla_device_launch_op.cc",
|
||||
"xla_device_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_device.h",
|
||||
"xla_device_context.h",
|
||||
"xla_device_launch_op.h",
|
||||
"xla_device_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":jit_compilation_passes",
|
||||
":xla_compilation_cache",
|
||||
":xla_local_launch_op",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core/kernels:assign_op",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:no_op",
|
||||
"//tensorflow/core/kernels:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_compilation_cache",
|
||||
srcs = ["xla_compilation_cache.cc"],
|
||||
hdrs = ["xla_compilation_cache.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jit_compilation_passes",
|
||||
srcs = ["jit_compilation_pass_registration.cc"],
|
||||
deps = [
|
||||
":compilation_passes",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compilation_passes",
|
||||
srcs = [
|
||||
"build_xla_launch_ops_pass.cc",
|
||||
"encapsulate_subgraphs_pass.cc",
|
||||
"graph_to_functiondef.cc",
|
||||
"mark_for_compilation_pass.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"build_xla_launch_ops_pass.h",
|
||||
"encapsulate_subgraphs_pass.h",
|
||||
"graph_to_functiondef.h",
|
||||
"mark_for_compilation_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":parallel_check_op",
|
||||
":xla_local_launch_op",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
|
||||
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
|
||||
"//tensorflow/compiler/tf2xla:const_analysis",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "compilation_passes_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"graph_to_functiondef_test.cc",
|
||||
"mark_for_compilation_pass_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":compilation_passes",
|
||||
":xla_local_launch_op",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_local_launch_op",
|
||||
srcs = ["xla_local_launch_op.cc"],
|
||||
hdrs = ["xla_local_launch_op.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_compilation_cache",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "parallel_check_op",
|
||||
srcs = ["parallel_check_op.cc"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
215
tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
Normal file
@ -0,0 +1,215 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static Status BuildLaunchNode(
|
||||
const string& nodename, const string& function_name,
|
||||
const AttrValueMap& function_attr, const string& device_name,
|
||||
const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
|
||||
const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
|
||||
NodeDef def;
|
||||
def.set_name(graph->NewName(nodename));
|
||||
def.set_op("_XlaLaunch");
|
||||
def.set_device(device_name);
|
||||
AddNodeAttr("Tconstants", constant_dtypes, &def);
|
||||
AddNodeAttr("Targs", arg_dtypes, &def);
|
||||
AddNodeAttr("Tresults", result_dtypes, &def);
|
||||
NameAttrList function;
|
||||
function.set_name(function_name);
|
||||
*function.mutable_attr() = function_attr;
|
||||
AddNodeAttr("function", function, &def);
|
||||
|
||||
Status status;
|
||||
*node = graph->AddNode(def, &status);
|
||||
return status;
|
||||
}
|
||||
|
||||
static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
|
||||
VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
|
||||
|
||||
int num_constant_args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
|
||||
|
||||
if (num_constant_args < 0 || num_constant_args > node->input_types().size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid number of constant arguments to XLA kernel");
|
||||
}
|
||||
DataTypeVector const_dtypes(node->input_types().begin(),
|
||||
node->input_types().begin() + num_constant_args);
|
||||
DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args,
|
||||
node->input_types().end());
|
||||
|
||||
// Build a _XlaLaunch operator to execute the function body.
|
||||
Node* launch_node;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildLaunchNode(graph->NewName(node->name()), node->type_string(),
|
||||
node->def().attr(), node->def().device(), const_dtypes,
|
||||
arg_dtypes, node->output_types(), graph, &launch_node));
|
||||
launch_node->set_assigned_device_name(node->assigned_device_name());
|
||||
|
||||
// Copy incoming edges to the launch node.
|
||||
for (const Edge* edge : node->in_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
graph->AddControlEdge(edge->src(), launch_node);
|
||||
} else {
|
||||
graph->AddEdge(edge->src(), edge->src_output(), launch_node,
|
||||
edge->dst_input());
|
||||
}
|
||||
}
|
||||
|
||||
// Copy outgoing edges to the launch node.
|
||||
std::vector<const Edge*> out_edges(node->out_edges().begin(),
|
||||
node->out_edges().end());
|
||||
for (const Edge* edge : out_edges) {
|
||||
Node* dst = edge->dst();
|
||||
int src_output = edge->src_output();
|
||||
int dst_input = edge->dst_input();
|
||||
graph->RemoveEdge(edge);
|
||||
|
||||
if (edge->IsControlEdge()) {
|
||||
graph->AddControlEdge(launch_node, dst);
|
||||
} else {
|
||||
graph->AddEdge(launch_node, src_output, dst, dst_input);
|
||||
}
|
||||
}
|
||||
graph->RemoveNode(node);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
Graph* graph = options.graph->get();
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
// In all cases, only try to compile computational nodes.
|
||||
if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only compile nodes that are marked for compilation by the
|
||||
// compilation-marking pass (via 'attr_name').
|
||||
if (IsXlaCompiledKernel(*n)) {
|
||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Given a NodeDef 'ndef' and the function library runtime 'flr', if
|
||||
// 'ndef' is a call to a compilable function defined in 'flr', returns OK
|
||||
// and fills in 'kernel' with an XlaLaunchOp kernel which computes the
|
||||
// node. Otherwise, returns a non-OK status.
|
||||
//
|
||||
// This routine is here so that FunctionLibraryRuntime can JIT-compile a
|
||||
// specific function call as requested.
|
||||
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
std::unique_ptr<OpKernel>* kernel) {
|
||||
bool xla_compile = false;
|
||||
if (!flr->GetFunctionLibraryDefinition()
|
||||
->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
|
||||
.ok() ||
|
||||
!xla_compile) {
|
||||
// Not marked as _XlaCompile=true.
|
||||
return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
|
||||
}
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterJitKernels();
|
||||
if (!IsCompilable(flr, ndef)) {
|
||||
// ndef is calling a function that XLA can't compile.
|
||||
return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
|
||||
}
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If ndef is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
|
||||
const FunctionBody* fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(fbody); // Can't be nullptr since we just instantiated it.
|
||||
std::vector<bool> const_args(fbody->arg_types.size());
|
||||
// If we can't analyze the const args, bail out.
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
|
||||
|
||||
for (int i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
// There is a const arg. Bail out.
|
||||
return errors::InvalidArgument("Const arg: ", i, " in ",
|
||||
DebugString(fbody->fdef));
|
||||
}
|
||||
}
|
||||
|
||||
NodeDef launch_def;
|
||||
launch_def.set_name(ndef.name());
|
||||
launch_def.set_op("_XlaLaunch");
|
||||
launch_def.set_device(flr->device()->name());
|
||||
AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
|
||||
AddNodeAttr("Targs", fbody->arg_types, &launch_def);
|
||||
AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
|
||||
NameAttrList func;
|
||||
func.set_name(ndef.op());
|
||||
*(func.mutable_attr()) = ndef.attr();
|
||||
AddNodeAttr("function", func, &launch_def);
|
||||
|
||||
// TODO(b/32387911): Handle the host memory types across function
|
||||
// calls properly. For now, we assume all inputs and outputs are on
|
||||
// the device memory.
|
||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &launch_def,
|
||||
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
|
||||
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
|
||||
kernel->reset(new XlaLocalLaunchOp(&construction));
|
||||
return s;
|
||||
}
|
||||
|
||||
bool RegisterLaunchOpCreator() {
|
||||
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool register_me = RegisterLaunchOpCreator();
|
||||
|
||||
} // end namespace
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/compiler/jit/build_xla_launch_ops_pass.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
|
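For context, a GraphOptimizationPass like the one declared above only takes effect once it is registered with TensorFlow's optimization pass registry. The sketch below shows the usual registration pattern; the grouping and phase number are illustrative assumptions, and the real registration for this commit presumably lives in jit_compilation_pass_registration.cc (listed in the BUILD target above), not in this header.

// Minimal sketch, assuming the standard OptimizationPassRegistry phases.
// The phase number 50 is an arbitrary example value.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
// Run the pass after graph rewriting, shortly before execution.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                      BuildXlaLaunchOpsPass);
}  // namespace tensorflow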
22
tensorflow/compiler/jit/defs.cc
Normal file
@ -0,0 +1,22 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaCompileAttr = "_XlaCompile";
|
||||
|
||||
} // namespace tensorflow
|
29
tensorflow/compiler/jit/defs.h
Normal file
@ -0,0 +1,29 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Provides definitions needed for use of the TensorFlow XLA
|
||||
// device.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_DEFS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_DEFS_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Name of attribute used to tag operators for compilation with XLA
|
||||
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_
|
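To illustrate how the kXlaCompileAttr constant is meant to be used, the following sketch tags a NodeDef with _XlaCompile=true and reads the flag back. This is an illustrative example only, not code from this commit; the helper names are hypothetical, and the AddNodeAttr/GetNodeAttr calls mirror how the attribute is queried elsewhere in this change (e.g. in CreateXlaLaunchOp).

// Minimal sketch: marking a node for XLA JIT compilation.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"

namespace tensorflow {

// Hypothetical helper: set "_XlaCompile" = true on a node.
void TagForXlaCompilation(NodeDef* ndef) {
  AddNodeAttr(kXlaCompileAttr, true, ndef);
}

// Hypothetical helper: true iff the node carries "_XlaCompile" = true.
bool IsTaggedForXlaCompilation(const NodeDef& ndef) {
  bool xla_compile = false;
  return GetNodeAttr(ndef, kXlaCompileAttr, &xla_compile).ok() && xla_compile;
}

}  // namespace tensorflow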
660
tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
Normal file
@ -0,0 +1,660 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
|
||||
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
|
||||
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
|
||||
|
||||
namespace {
|
||||
|
||||
// A node/slot pair.
|
||||
// TODO(phawkins): is there a common definition of this?
|
||||
struct NodeSlot {
|
||||
NodeSlot() : node(nullptr), slot(-1) {}
|
||||
NodeSlot(const Node* node, int slot) : node(node), slot(slot) {}
|
||||
|
||||
const Node* node;
|
||||
int slot;
|
||||
|
||||
bool operator==(const NodeSlot& other) const {
|
||||
return node == other.node && slot == other.slot;
|
||||
}
|
||||
|
||||
struct Hasher {
|
||||
uint64 operator()(NodeSlot const& s) const {
|
||||
return Hash64Combine(std::hash<const Node*>()(s.node),
|
||||
std::hash<int>()(s.slot));
|
||||
}
|
||||
};
|
||||
|
||||
struct PairHasher {
|
||||
uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
|
||||
return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
class Encapsulator {
|
||||
public:
|
||||
Encapsulator(string group_attribute, Graph const* graph_in)
|
||||
: group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
|
||||
|
||||
// Find subgraphs marked with 'group_attribute', and build a new
|
||||
// subgraph, one for each value of 'group_attribute'.
|
||||
Status SplitIntoSubgraphs();
|
||||
|
||||
// Build a FunctionDef for each subgraph, and add it to 'library'. The values of
|
||||
// the 'group_attribute' annotations become the function names.
|
||||
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
|
||||
// function conversion.
|
||||
Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
// Write a copy of the input graph to 'graph_out', where the subgraphs are
|
||||
// replaced with calls to the new functions.
|
||||
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
|
||||
|
||||
private:
|
||||
// Returns the key attribute associated with a node. Returns the empty string
|
||||
// if no key attribute is found.
|
||||
string GetFunctionNameAttr(const Node* node) const;
|
||||
|
||||
// A subgraph of the input, all marked with a common 'group_attribute'
|
||||
// value.
|
||||
struct Subgraph {
|
||||
// The subgraph extracted from the input graph, suitable for being turned
|
||||
// into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
|
||||
// returned by _Retval nodes.
|
||||
std::unique_ptr<Graph> graph;
|
||||
|
||||
// Which device are these nodes on? Used both to check that all nodes
|
||||
// are assigned to the same device, and to assign a device to the call node.
|
||||
string device;
|
||||
|
||||
// NodeDef for the function call node.
|
||||
NodeDef call_node_def;
|
||||
|
||||
// Function call node(s) in the output graph. Not owned.
|
||||
// If parallel_checking is enabled, 'call_node_inputs' is the function call
|
||||
// node to which inputs should be fed, and 'call_node_outputs' is the
|
||||
// parallel check op from which outputs should be read. If parallel checking
|
||||
// is disabled, both point to the function call node.
|
||||
Node* call_node_inputs;
|
||||
Node* call_node_outputs;
|
||||
|
||||
// Maps from source (producer node/slot) and destination
|
||||
// (consumer node/slot) tensors in the input graph to _Arg numbers in
|
||||
// the subgraph. The source map is one-to-one, whereas the dest map may be
|
||||
// many-to-one.
|
||||
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src;
|
||||
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst;
|
||||
|
||||
// The _Arg nodes in the subgraph, in order by argument number.
|
||||
std::vector<Node*> args;
|
||||
|
||||
// Map from source tensor in the input graph to result #.
|
||||
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results;
|
||||
};
|
||||
|
||||
// Builds a ParallelCheck op that compares the output of the original subgraph
|
||||
// with the encapsulated subgraph.
|
||||
Status BuildParallelCheckOp(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op);
|
||||
|
||||
const string group_attribute_;
|
||||
const Graph* graph_in_;
|
||||
|
||||
std::unordered_map<string, Subgraph> subgraphs_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
|
||||
};
|
||||
|
||||
// TODO(phawkins) add a canonical copy of these operator names and refactor
|
||||
// everything to use it.
|
||||
static const char* const kArgOp = "_Arg";
|
||||
static const char* const kRetValOp = "_Retval";
|
||||
|
||||
// Returns the function name attached to 'node', or the empty string if there is
|
||||
// none.
|
||||
string Encapsulator::GetFunctionNameAttr(Node const* node) const {
|
||||
string attr;
|
||||
if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
|
||||
attr.clear();
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
Status Encapsulator::SplitIntoSubgraphs() {
|
||||
Status s;
|
||||
|
||||
// Map from input graph nodes to subgraph nodes.
|
||||
std::unordered_map<Node*, Node*> node_images;
|
||||
|
||||
// Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
|
||||
for (Node* node : graph_in_->nodes()) {
|
||||
if (node->IsSource() || node->IsSink()) continue;
|
||||
string func_id = GetFunctionNameAttr(node);
|
||||
if (func_id.empty()) continue;
|
||||
|
||||
Subgraph& subgraph = subgraphs_[func_id];
|
||||
if (!subgraph.graph) {
|
||||
subgraph.graph.reset(new Graph(graph_in_->op_registry()));
|
||||
subgraph.graph->set_versions(graph_in_->versions());
|
||||
}
|
||||
|
||||
Node* image = subgraph.graph->CopyNode(node);
|
||||
image->ClearAttr(group_attribute_);
|
||||
node_images[node] = image;
|
||||
|
||||
// Check the device matches any existing device.
|
||||
string device = node->assigned_device_name().empty()
|
||||
? node->def().device()
|
||||
: node->assigned_device_name();
|
||||
|
||||
if (subgraph.device.empty()) {
|
||||
subgraph.device = device;
|
||||
} else if (subgraph.device != device) {
|
||||
s.Update(errors::InvalidArgument(
|
||||
"Mismatched devices for nodes to be grouped by Encapsulator"));
|
||||
}
|
||||
}
|
||||
|
||||
// Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for
|
||||
// data edges that cross subgraph boundaries.
|
||||
for (const Edge* edge : graph_in_->edges()) {
|
||||
string src_func_id = GetFunctionNameAttr(edge->src());
|
||||
string dst_func_id = GetFunctionNameAttr(edge->dst());
|
||||
Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
|
||||
Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
|
||||
|
||||
// Copy edges that are local to a subgraph.
|
||||
if (!src_func_id.empty() && src_func_id == dst_func_id) {
|
||||
Graph* g = subgraphs_[src_func_id].graph.get();
|
||||
if (edge->IsControlEdge()) {
|
||||
g->AddControlEdge(src_image, dst_image);
|
||||
} else {
|
||||
g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ignore cross-boundary control edges for now. We will lift them
|
||||
// onto the enclosing call operators in BuildOutputGraph().
|
||||
if (edge->IsControlEdge()) continue;
|
||||
|
||||
// Add 'src' as an output of its subgraph, if applicable.
|
||||
if (!src_func_id.empty()) {
|
||||
Subgraph& src_subgraph = subgraphs_[src_func_id];
|
||||
int ret_index = src_subgraph.results.size();
|
||||
if (src_subgraph.results
|
||||
.emplace(NodeSlot(edge->src(), edge->src_output()), ret_index)
|
||||
.second) {
|
||||
// Create a new _Retval node
|
||||
DataType dtype = edge->src()->output_type(edge->src_output());
|
||||
|
||||
NodeDef ret_def;
|
||||
ret_def.set_op(kRetValOp);
|
||||
ret_def.set_name(src_subgraph.graph->NewName("output"));
|
||||
AddNodeAttr("T", dtype, &ret_def);
|
||||
AddNodeAttr("index", ret_index, &ret_def);
|
||||
Node* ret = src_subgraph.graph->AddNode(ret_def, &s);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
// Add an edge from 'src' to _Retval.
|
||||
src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Add 'dst' as an input of its subgraph, if applicable.
|
||||
if (!dst_func_id.empty()) {
|
||||
Subgraph& dst_subgraph = subgraphs_[dst_func_id];
|
||||
|
||||
// Create an _Arg node for this tensor, if none exists yet.
|
||||
std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
|
||||
bool inserted;
|
||||
std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace(
|
||||
NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size());
|
||||
int arg_index = iter->second;
|
||||
if (inserted) {
|
||||
// This is the first time we have seen this tensor. Create an _Arg node.
|
||||
DataType dtype = edge->dst()->input_type(edge->dst_input());
|
||||
|
||||
NodeDef arg_def;
|
||||
NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp);
|
||||
builder.Attr("T", dtype);
|
||||
builder.Attr("index", arg_index);
|
||||
s = builder.Finalize(&arg_def);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
Node* arg = dst_subgraph.graph->AddNode(arg_def, &s);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
dst_subgraph.args.push_back(arg);
|
||||
}
|
||||
// Add an edge from the _Arg node to 'dst' in the subgraph.
|
||||
dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
|
||||
arg_index;
|
||||
dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image,
|
||||
edge->dst_input());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& entry : subgraphs_) {
|
||||
FixupSourceAndSinkEdges(entry.second.graph.get());
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
Status Encapsulator::BuildFunctionDefs(
|
||||
const RewriteSubgraphFn& rewrite_subgraph_fn,
|
||||
FunctionLibraryDefinition* library) {
|
||||
// For each subgraph, build a FunctionDef.
|
||||
for (auto& subgraph_entry : subgraphs_) {
|
||||
const string& name = subgraph_entry.first;
|
||||
Subgraph& subgraph = subgraph_entry.second;
|
||||
|
||||
subgraph.call_node_def.set_op(name);
|
||||
subgraph.call_node_def.set_name(name);
|
||||
subgraph.call_node_def.set_device(subgraph.device);
|
||||
|
||||
if (rewrite_subgraph_fn) {
|
||||
// Initialize the input and output permutations to the identity.
|
||||
std::vector<int> input_permutation(subgraph.args_by_src.size());
|
||||
std::iota(input_permutation.begin(), input_permutation.end(), 0);
|
||||
std::vector<int> output_permutation(subgraph.results.size());
|
||||
std::iota(output_permutation.begin(), output_permutation.end(), 0);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
rewrite_subgraph_fn(&subgraph.graph, &input_permutation,
|
||||
&output_permutation, &subgraph.call_node_def));
|
||||
|
||||
// Apply the input/output permutations to the 'args_by_...' and 'results'
|
||||
// mappings in 'subgraph', so when we build edges in BuildOutputGraph() we
|
||||
// connect them to the right input/output positions.
|
||||
if (input_permutation.size() != subgraph.args_by_src.size()) {
|
||||
return errors::InvalidArgument("Input permutation has incorrect size.");
|
||||
}
|
||||
if (output_permutation.size() != subgraph.results.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Output permutation has incorrect size.");
|
||||
}
|
||||
for (auto& arg : subgraph.args_by_src) {
|
||||
arg.second = input_permutation[arg.second];
|
||||
}
|
||||
for (auto& arg : subgraph.args_by_dst) {
|
||||
arg.second = input_permutation[arg.second];
|
||||
}
|
||||
for (auto& result : subgraph.results) {
|
||||
result.second = output_permutation[result.second];
|
||||
}
|
||||
}
|
||||
|
||||
FunctionDef fdef;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef));
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(2) << "Build function def " << name;
|
||||
dump_graph::DumpGraphToFile(
|
||||
strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph,
|
||||
library);
|
||||
dump_graph::DumpFunctionDefToFile(
|
||||
strings::StrCat("encapsulate_fdef_", name), fdef);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Encapsulator::BuildParallelCheckOp(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const Encapsulator::Subgraph& subgraph, Graph* graph_out,
|
||||
Node** parallel_check_op) {
|
||||
// Build an index mapping output positions to node/slot pairs in the
|
||||
// original graph.
|
||||
std::vector<NodeSlot> results_by_num(subgraph.results.size());
|
||||
for (const auto& entry : subgraph.results) {
|
||||
results_by_num[entry.second] = entry.first;
|
||||
}
|
||||
|
||||
// Build a parallel check NodeDef.
|
||||
int num_results = results_by_num.size();
|
||||
std::vector<DataType> result_dtypes(num_results);
|
||||
std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
|
||||
std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
const NodeSlot& node_slot = results_by_num[i];
|
||||
result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
|
||||
expected_outputs[i] =
|
||||
NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
|
||||
node_slot.slot, result_dtypes[i]);
|
||||
actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(),
|
||||
i, result_dtypes[i]);
|
||||
}
|
||||
// Assign the parallel check op to a CPU on the same task as the cluster it is
|
||||
// checking.
|
||||
string device, dummy;
|
||||
if (!DeviceNameUtils::SplitDeviceName(
|
||||
subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) {
|
||||
return errors::InvalidArgument("Could not parse device name");
|
||||
}
|
||||
strings::StrAppend(&device, "/cpu:0");
|
||||
|
||||
NodeDef check_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeDefBuilder(graph_out->NewName(strings::StrCat(
|
||||
subgraph.call_node_def.name(), "_parallel_check")),
|
||||
"ParallelCheck")
|
||||
.Device(device)
|
||||
.Attr("T", result_dtypes)
|
||||
.Input(expected_outputs)
|
||||
.Input(actual_outputs)
|
||||
.Finalize(&check_def));
|
||||
|
||||
Status s;
|
||||
Node* check_op = graph_out->AddNode(check_def, &s);
|
||||
if (!s.ok()) return s;
|
||||
check_op->set_assigned_device_name(device);
|
||||
|
||||
// TODO(phawkins): it seems redundant to call AddEdge as well as
|
||||
// pass Inputs to the NodeDefBuilder, but I have been unable to find a
|
||||
// way to avoid it.
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
const NodeSlot& node_slot = results_by_num[i];
|
||||
graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
|
||||
i);
|
||||
graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i);
|
||||
}
|
||||
|
||||
*parallel_check_op = check_op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Encapsulator::BuildOutputGraph(bool parallel_checking,
|
||||
Graph* graph_out) {
|
||||
Status s;
|
||||
|
||||
// Map from nodes in the input graph to nodes in the output graph.
|
||||
std::unordered_map<const Node*, Node*> node_images;
|
||||
|
||||
// Copy all unmarked nodes to the output graph.
|
||||
for (Node* node : graph_in_->nodes()) {
|
||||
if (node->IsSource() || node->IsSink()) continue;
|
||||
string func_id = GetFunctionNameAttr(node);
|
||||
|
||||
// Don't copy nodes that are going to be encapsulated, unless parallel checking
|
||||
// is enabled.
|
||||
if (!func_id.empty() && !parallel_checking) continue;
|
||||
|
||||
Node* image = graph_out->CopyNode(node);
|
||||
node_images[node] = image;
|
||||
}
|
||||
node_images[graph_in_->source_node()] = graph_out->source_node();
|
||||
node_images[graph_in_->sink_node()] = graph_out->sink_node();
|
||||
|
||||
// Add function call nodes for each subgraph.
|
||||
for (auto& subgraph_entry : subgraphs_) {
|
||||
Subgraph& subgraph = subgraph_entry.second;
|
||||
|
||||
subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
// Copy the assigned device and the key_annotation over.
|
||||
subgraph.call_node_inputs->set_assigned_device_name(subgraph.device);
|
||||
subgraph.call_node_outputs = subgraph.call_node_inputs;
|
||||
|
||||
if (parallel_checking) {
|
||||
TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out,
|
||||
&subgraph.call_node_outputs));
|
||||
}
|
||||
}
|
||||
|
||||
// Set of edges already added to the output graph, represented as (src, dst)
|
||||
// pairs. We use the set to deduplicate edges; multiple edges in the input
|
||||
// graph may map to one edge in the output graph.
|
||||
std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
|
||||
edges_added;
|
||||
|
||||
// Add edges to the graph_out graph.
|
||||
for (const Edge* edge : graph_in_->edges()) {
|
||||
string src_func_id = GetFunctionNameAttr(edge->src());
|
||||
string dst_func_id = GetFunctionNameAttr(edge->dst());
|
||||
|
||||
// Ignore edges that are strictly contained within one subgraph, unless
|
||||
// we are constructing parallel check graphs.
|
||||
if (!src_func_id.empty() && src_func_id == dst_func_id) {
|
||||
if (parallel_checking) {
|
||||
Node* src_image = node_images.at(edge->src());
|
||||
Node* dst_image = node_images.at(edge->dst());
|
||||
if (edge->IsControlEdge()) {
|
||||
graph_out->AddControlEdge(src_image, dst_image);
|
||||
} else {
|
||||
graph_out->AddEdge(src_image, edge->src_output(), dst_image,
|
||||
edge->dst_input());
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// We have an edge that crosses a cluster boundary.
|
||||
Node* src_image = src_func_id.empty()
|
||||
? node_images.at(edge->src())
|
||||
: subgraphs_.at(src_func_id).call_node_outputs;
|
||||
Node* dst_image = dst_func_id.empty()
|
||||
? node_images.at(edge->dst())
|
||||
: subgraphs_.at(dst_func_id).call_node_inputs;
|
||||
|
||||
// Copy control edges. Lift control edges onto the enclosing call operator.
|
||||
if (edge->IsControlEdge()) {
|
||||
// Add the control edge, if we have not already added it.
|
||||
if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
|
||||
.second) {
|
||||
graph_out->AddControlEdge(src_image, dst_image);
|
||||
}
|
||||
|
||||
// If parallel checking is enabled, also add a control edge to the
|
||||
// corresponding parallel check op.
|
||||
if (parallel_checking) {
|
||||
graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
int src_output = edge->src_output();
|
||||
if (!src_func_id.empty()) {
|
||||
// 'src' is in a subgraph. Use the corresponding call output instead.
|
||||
const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
|
||||
src_output =
|
||||
src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output()));
|
||||
}
|
||||
|
||||
int dst_input = edge->dst_input();
|
||||
|
||||
if (!dst_func_id.empty()) {
|
||||
// 'dst' is in a subgraph. Use the corresponding call input instead.
|
||||
const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
|
||||
dst_input =
|
||||
dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
|
||||
|
||||
// If we are parallel checking, also feed the tensor as an input to the
|
||||
// corresponding parallel check subgraph.
|
||||
if (parallel_checking) {
|
||||
graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
|
||||
edge->dst_input());
|
||||
}
|
||||
}
|
||||
// Add the edge, if we have not already added it.
|
||||
if (edges_added
|
||||
.emplace(NodeSlot(src_image, src_output),
|
||||
NodeSlot(dst_image, dst_input))
|
||||
.second) {
|
||||
graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
|
||||
}
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status EncapsulateSubgraphsInFunctions(
|
||||
string group_attribute, const Graph& graph_in,
|
||||
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
|
||||
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
|
||||
Status s;
|
||||
|
||||
Encapsulator encapsulator(std::move(group_attribute), &graph_in);
|
||||
s = encapsulator.SplitIntoSubgraphs();
|
||||
if (!s.ok()) return s;
|
||||
|
||||
s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
std::unique_ptr<Graph> out(new Graph(library));
|
||||
out->set_versions(graph_in.versions());
|
||||
s = encapsulator.BuildOutputGraph(parallel_checking, out.get());
|
||||
if (!s.ok()) return s;
|
||||
|
||||
*graph_out = std::move(out);
|
||||
return s;
|
||||
}
|
||||
|
||||
// Renumber the indices of _Arg nodes in a graph, according to
|
||||
// 'permutation', which maps old indices to new indices.
|
||||
static Status RenumberArguments(Graph* graph,
|
||||
const std::vector<int>& permutation) {
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
if (index < 0 || index >= permutation.size()) {
|
||||
return errors::InvalidArgument("Invalid argument number");
|
||||
}
|
||||
n->AddAttr("index", permutation[index]);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EncapsulateSubgraphsPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
VLOG(1) << "EncapsulateSubgraphsPass::Run";
|
||||
legacy_flags::EncapsulateSubgraphsPassFlags* flags =
|
||||
legacy_flags::GetEncapsulateSubgraphsPassFlags();
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
|
||||
options.flib_def);
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
FunctionLibraryDefinition* const library = options.flib_def;
|
||||
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<FunctionLibraryRuntime> flr(
|
||||
NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
|
||||
TF_GRAPH_DEF_VERSION, library, opts));
|
||||
|
||||
auto rewrite_subgraph = [&flr](
|
||||
std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
|
||||
std::vector<int>* output_permutation, NodeDef* node) {
|
||||
// Optimize the subgraph.
|
||||
Graph* g = subgraph->release();
|
||||
OptimizeGraph(flr.get(), &g);
|
||||
subgraph->reset(g);
|
||||
|
||||
std::vector<bool> const_args(input_permutation->size());
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args));
|
||||
|
||||
// Compute a permutation of the arguments such that the constant arguments
|
||||
// are first.
|
||||
const int num_consts =
|
||||
std::count(const_args.begin(), const_args.end(), true);
|
||||
int const_pos = 0;
|
||||
int arg_pos = num_consts;
|
||||
for (int i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
(*input_permutation)[i] = const_pos;
|
||||
++const_pos;
|
||||
} else {
|
||||
(*input_permutation)[i] = arg_pos;
|
||||
++arg_pos;
|
||||
}
|
||||
}
|
||||
|
||||
// Renumber argument nodes in the graph.
|
||||
TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
|
||||
|
||||
// TODO(phawkins): add a forward is-constant analysis, similarly split
|
||||
// outputs into host-memory constants and device-memory non-constants.
|
||||
|
||||
AddNodeAttr(kXlaCompiledKernelAttr, true, node);
|
||||
AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
|
||||
kXlaClusterAttr, **options.graph, rewrite_subgraph,
|
||||
flags->tf_xla_parallel_checking, &graph_out, library));
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
|
||||
options.flib_def);
|
||||
}
|
||||
|
||||
*options.graph = std::move(graph_out);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsXlaCompiledKernel(const Node& node) {
|
||||
bool is_compiled = false;
|
||||
bool has_compilation_attr =
|
||||
GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
|
||||
is_compiled;
|
||||
return has_compilation_attr ? is_compiled : false;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
86
tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// An optimization pass that groups nodes marked with a common
|
||||
// kXlaClusterAttr into functions, and replaces the original nodes by
|
||||
// calls. The calls are annotated with kXlaCompiledKernelAttr.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A rewriting function to apply to each subgraph during encapsulation.
|
||||
// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
|
||||
// 'input_permutation' is a mapping from old argument numbers to new argument
|
||||
// numbers, whereas 'output_permutation' is the same for outputs. Both
|
||||
// 'input_permutation' and 'output_permutation' are initialized to the identity
|
||||
// permutation. 'nodedef' is the NodeDef for the call to the function under
|
||||
// construction, provided to allow additional attributes to be set.
|
||||
typedef std::function<Status(
|
||||
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
|
||||
std::vector<int>* output_permutation, NodeDef* node_def)>
|
||||
RewriteSubgraphFn;
|
||||
|
||||
// Transformation that finds subgraphs whose nodes are marked with
|
||||
// 'group_attribute', splits those subgraphs into functions, and replaces
|
||||
// the originals with function calls.
|
||||
//
|
||||
// 'group_attribute' must be a string valued-attribute that names the new
|
||||
// functions to introduce.
|
||||
//
|
||||
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
|
||||
// function conversion.
|
||||
//
|
||||
// If 'parallel_checking' is true, the unencapsulated operators are added to the
|
||||
// output graph, together with a "ParallelCheck" operator, that verifies that
|
||||
// the original and encapsulated subgraphs produce similar results.
|
||||
//
|
||||
// TODO(phawkins): currently, some information in control edges
|
||||
// is not preserved. Suppose you have A and B in the main
|
||||
// graph, C and D in a subgraph. B and C have control deps from A, D has control
|
||||
// dep from B. Originally D must run after C, post-transformation this
|
||||
// dependency is lost.
|
||||
Status EncapsulateSubgraphsInFunctions(
|
||||
string group_attribute, const Graph& graph_in,
|
||||
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
|
||||
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
|
||||
|
||||
// The attribute that marks function calls produced by the encapsulate
|
||||
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
|
||||
extern const char* const kXlaCompiledKernelAttr;
|
||||
|
||||
// Does `node` have the kXlaCompiledKernelAttr attribute?
|
||||
bool IsXlaCompiledKernel(const Node& node);
|
||||
|
||||
// Functions produced by the EncapsulateSubgraphs pass have their arguments
|
||||
// ordered such that compile-time constant arguments are first in the argument
|
||||
// order. The functions are annotated with the following attribute giving the
|
||||
// number of constant arguments.
|
||||
extern const char* const kXlaNumConstantArgsAttr;
|
||||
|
||||
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
|
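As a usage illustration of the API declared in this header, the sketch below groups all nodes carrying a hypothetical "_MyCluster" attribute into function calls, using a rewrite function that leaves each subgraph and its input/output permutations unchanged. The attribute name, function name, and surrounding setup are assumptions made for the example; they are not part of this commit.

// Minimal sketch, assuming a graph whose nodes are annotated with "_MyCluster".
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"

namespace tensorflow {

Status EncapsulateMyClusters(const Graph& graph_in,
                             FunctionLibraryDefinition* library,
                             std::unique_ptr<Graph>* graph_out) {
  // Identity rewrite: keep the subgraph and the identity permutations as-is.
  RewriteSubgraphFn identity_rewrite =
      [](std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
         std::vector<int>* output_permutation, NodeDef* node_def) {
        return Status::OK();
      };
  return EncapsulateSubgraphsInFunctions(
      "_MyCluster", graph_in, identity_rewrite,
      /*parallel_checking=*/false, graph_out, library);
}

}  // namespace tensorflow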
397
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
Normal file
@ -0,0 +1,397 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
|
||||
string* diff) {
|
||||
// TODO(phawkins) use a more sophisticated equality test.
|
||||
if (a.DebugString() != b.DebugString()) {
|
||||
if (diff) {
|
||||
*diff = strings::StrCat("Definition mismatch for function ",
|
||||
a.signature().name(), ", expected:\n",
|
||||
a.DebugString());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
|
||||
const FunctionDefLibrary& actual, string* diff) {
|
||||
std::unordered_map<string, const FunctionDef*> actual_index;
|
||||
for (const FunctionDef& function : actual.function()) {
|
||||
actual_index[function.signature().name()] = &function;
|
||||
}
|
||||
|
||||
for (const FunctionDef& expected_function : expected.function()) {
|
||||
auto it = actual_index.find(expected_function.signature().name());
|
||||
if (it == actual_index.end()) {
|
||||
if (diff) {
|
||||
*diff = strings::StrCat("Did not find expected function '",
|
||||
expected_function.signature().name(), "'");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
|
||||
actual_index.erase(it);
|
||||
}
|
||||
|
||||
if (!actual_index.empty()) {
|
||||
if (diff != nullptr) {
|
||||
*diff = strings::StrCat("Found unexpected function '",
|
||||
actual_index.begin()->second->signature().name(),
|
||||
"'");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \
|
||||
do { \
|
||||
string diff; \
|
||||
EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \
|
||||
<< diff << "\nActual: " << actual.DebugString(); \
|
||||
} while (false)
|
||||
|
||||
REGISTER_OP("InputTest").Output("o: float");
|
||||
|
||||
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
|
||||
REGISTER_OP("BinaryTest")
|
||||
.Input("a: float")
|
||||
.Input("b: float")
|
||||
.Output("o: float");
|
||||
|
||||
REGISTER_OP("AddNLikeTest")
|
||||
.Input("inputs: N * T")
|
||||
.Output("sum: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: numbertype")
|
||||
.SetIsCommutative()
|
||||
.SetIsAggregate();
|
||||
|
||||
Node* Input(const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("InputTest", opts);
|
||||
}
|
||||
|
||||
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
|
||||
return ops::UnaryOp("UnaryTest", a, opts);
|
||||
}
|
||||
|
||||
Node* Binary(ops::NodeOut a, ops::NodeOut b,
|
||||
const GraphDefBuilder::Options& opts) {
|
||||
return ops::BinaryOp("BinaryTest", a, b, opts);
|
||||
}
|
||||
|
||||
Node* AddNLike(std::vector<ops::NodeOut> inputs,
|
||||
const GraphDefBuilder::Options& opts) {
|
||||
if (opts.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
|
||||
opts.op_registry());
|
||||
node_builder.Input(inputs);
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
}
|
||||
|
||||
Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("_Arg",
|
||||
opts.WithAttr("T", type).WithAttr("index", index));
|
||||
}
|
||||
|
||||
Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
|
||||
if (opts.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
|
||||
opts.op_registry());
|
||||
node_builder.Input(a).Attr("index", index);
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
}
|
||||
|
||||
Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
|
||||
Status s;
|
||||
// Convert the GraphDef to a Graph
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), *library));
|
||||
GraphConstructorOptions options;
|
||||
options.allow_internal_ops = true;
|
||||
std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
|
||||
s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
|
||||
if (!s.ok()) return s;
|
||||
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
|
||||
/* rewrite_subgraph_fn= */ {},
|
||||
/* parallel_checking= */ false,
|
||||
&graph_out, lib_def.get());
|
||||
if (!s.ok()) return s;
|
||||
|
||||
GraphDef graphdef_out;
|
||||
graph_out->ToGraphDef(&graphdef_out);
|
||||
graphdef->Swap(&graphdef_out);
|
||||
|
||||
*library = lib_def->ToProto();
|
||||
return s;
|
||||
}
|
||||
|
||||
// If there are no marked nodes, funcification should be a no-op.
|
||||
TEST(EncapsulateSubgraphsTest, NoFunctions) {
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
|
||||
Node* a = Input(builder.opts().WithName("A"));
|
||||
Node* b = Input(builder.opts().WithName("B"));
|
||||
Node* c = Unary(a, builder.opts().WithName("C"));
|
||||
Binary(b, c, builder.opts().WithName("D"));
|
||||
|
||||
GraphDef graphdef_in;
|
||||
FunctionDefLibrary library_in;
|
||||
builder.ToGraphDef(&graphdef_in);
|
||||
*library_in.add_function() = test::function::XTimesTwo();
|
||||
|
||||
GraphDef graphdef_out = graphdef_in;
|
||||
FunctionDefLibrary library_out = library_in;
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
|
||||
}
|
||||
|
||||
// Test with one function to transform.
|
||||
TEST(EncapsulateSubgraphsTest, OneFunction) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
*library.add_function() = test::function::XTimesTwo();
|
||||
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
// Give nodes 'c' and 'd' names that collide after lowercasing.
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Binary(a, d, b1.opts().WithName("E"));
|
||||
b1.ToGraphDef(&graphdef);
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = test::function::XTimesTwo();
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"input__0"}},
|
||||
{{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}},
|
||||
},
|
||||
{{"output__2", "c:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
NodeBuilder node_builder("F1", "F1", lib_def.get());
|
||||
node_builder.Input(a).Input(b);
|
||||
Node* call = b2.opts().FinalizeBuilder(&node_builder);
|
||||
|
||||
Binary(a, call, b2.opts().WithName("E"));
|
||||
b2.ToGraphDef(&graphdef_expected);
|
||||
}
|
||||
|
||||
// Check that the encapsulation produced the expected graph and library.
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with two functions to transform.
|
||||
TEST(EncapsulateSubgraphsTest, TwoFunctions) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
*library.add_function() = test::function::XTimesTwo();
|
||||
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* control = Input(b1.opts().WithName("Control"));
|
||||
Node* c =
|
||||
Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
|
||||
"_encapsulate", "F2"));
|
||||
Binary(a, d, b1.opts().WithName("E"));
|
||||
b1.ToGraphDef(&graphdef);
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = test::function::XTimesTwo();
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"input__0:float"}, {"output__1:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"input__0"}},
|
||||
},
|
||||
{{"output__1", "C:o:0"}});
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
|
||||
{
|
||||
{{"D"}, "BinaryTest", {"input__0", "input__1"}},
|
||||
},
|
||||
{{"output__2", "D:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
Node* control = Input(b2.opts().WithName("Control"));
|
||||
|
||||
NodeBuilder nb("F1", "F1", lib_def.get());
|
||||
nb.Input(a).ControlInput(control);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&nb);
|
||||
|
||||
NodeBuilder nb2("F2", "F2", lib_def.get());
|
||||
nb2.Input(b).Input(call1).ControlInput(control);
|
||||
Node* call2 = b2.opts().FinalizeBuilder(&nb2);
|
||||
|
||||
Binary(a, call2, b2.opts().WithName("E"));
|
||||
b2.ToGraphDef(&graphdef_expected);
|
||||
}
|
||||
|
||||
// Check that the encapsulation produced the expected graph and library.
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Returns a vector of node names in 'graph', sorted by name.
|
||||
std::vector<string> GraphNodes(const Graph& graph) {
|
||||
std::vector<string> nodes;
|
||||
for (const auto& node : graph.nodes()) {
|
||||
if (!node->IsSource() && !node->IsSink()) {
|
||||
nodes.push_back(node->name());
|
||||
}
|
||||
}
|
||||
std::sort(nodes.begin(), nodes.end());
|
||||
return nodes;
|
||||
}
|
||||
|
||||
// Returns a sorted vector of (src, dst) edges in 'graph'.
|
||||
std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
|
||||
std::vector<std::pair<string, string>> edges;
|
||||
for (const Edge* edge : graph.edges()) {
|
||||
if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
|
||||
edges.emplace_back(
|
||||
strings::StrCat(edge->src()->name(), ":", edge->src_output()),
|
||||
strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
|
||||
}
|
||||
std::sort(edges.begin(), edges.end());
|
||||
return edges;
|
||||
}
|
||||
|
||||
TEST(EncapsulateSubgraphsTest, InputDeduplication) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
|
||||
"/job:localhost/replica:0/task:0/cpu:0");
|
||||
auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
|
||||
auto add1 = ops::Add(root.WithOpName("add1"), x, x);
|
||||
add1.node()->AddAttr("_cluster", "cluster1");
|
||||
auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
|
||||
add2.node()->AddAttr("_cluster", "cluster2");
|
||||
auto out = ops::Mul(root.WithOpName("mul"), add1, add2);
|
||||
|
||||
Graph graph_before_encapsulation(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
|
||||
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
|
||||
/*parallel_checking=*/false, &graph, &library));
|
||||
|
||||
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
|
||||
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
|
||||
|
||||
std::vector<std::pair<string, string>> expected_edges = {
|
||||
{"cluster1:0", "cluster2:0"},
|
||||
{"cluster1:0", "mul:0"},
|
||||
{"cluster2:0", "mul:1"},
|
||||
{"x:0", "cluster1:0"}};
|
||||
EXPECT_EQ(expected_edges, GraphEdges(*graph));
|
||||
}
|
||||
|
||||
TEST(EncapsulateSubgraphsTest, ParallelChecking) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
|
||||
"/job:localhost/replica:0/task:0/cpu:0");
|
||||
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
|
||||
auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
|
||||
auto add1 = ops::Add(root.WithOpName("add1"), x1, x2);
|
||||
add1.node()->AddAttr("_cluster", "cluster1");
|
||||
auto add2 = ops::Add(root.WithOpName("add2"), add1, x2);
|
||||
add2.node()->AddAttr("_cluster", "cluster1");
|
||||
auto out = ops::Mul(root.WithOpName("mul"), x1, add2);
|
||||
|
||||
Graph graph_before_encapsulation(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
|
||||
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
|
||||
/*parallel_checking=*/true, &graph, &library));
|
||||
|
||||
std::vector<string> expected_nodes = {
|
||||
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",
|
||||
"mul", "x1", "x2"};
|
||||
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
|
||||
|
||||
std::vector<std::pair<string, string>> expected_edges = {
|
||||
{"add1:0", "add2:0"},
|
||||
{"add2:0", "cluster1_parallel_check/_0:0"},
|
||||
{"cluster1:0", "cluster1_parallel_check/_0:1"},
|
||||
{"cluster1_parallel_check/_0:0", "mul:1"},
|
||||
{"x1:0", "add1:0"},
|
||||
{"x1:0", "cluster1:0"},
|
||||
{"x1:0", "mul:0"},
|
||||
{"x2:0", "add1:1"},
|
||||
{"x2:0", "add2:1"},
|
||||
{"x2:0", "cluster1:1"},
|
||||
};
|
||||
EXPECT_EQ(expected_edges, GraphEdges(*graph));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
274
tensorflow/compiler/jit/graph_to_functiondef.cc
Normal file
@ -0,0 +1,274 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// TODO(phawkins) add a canonical copy of these operator names and refactor
|
||||
// everything to use it.
|
||||
const char* const kArgOp = "_Arg";
|
||||
const char* const kRetValOp = "_Retval";
|
||||
|
||||
// Class that maintains a one-to-one original node name -> new name mapping.
|
||||
// We have to normalize the names used as input and output arguments to
|
||||
// match regexp "[a-z][a-z0-9_]*". Once we rename them, we risk creating
|
||||
// a name collision with the other node names, so if necessary we add
|
||||
// a suffix to make names unique. So if we have an input named "A" and a
|
||||
// node in the function body named "a", they will be renamed to "a" and "a_0".
|
||||
class NodeNameMapping {
|
||||
public:
|
||||
NodeNameMapping() = default;
|
||||
|
||||
// Normalize the input/output name and then make it unique.
|
||||
string Normalize(const string& name);
|
||||
|
||||
// Make the node name unique.
|
||||
string Uniquify(const string& name);
|
||||
|
||||
// Look up how a node name was previously normalized/uniquified.
|
||||
// Returns empty if name was never seen.
|
||||
string Renormalize(const string& name) const;
|
||||
|
||||
private:
|
||||
string NormalizeHelper(string name) const;
|
||||
string UniquifyHelper(string name);
|
||||
|
||||
std::unordered_set<string> used_names_;
|
||||
std::unordered_map<string, string> name_mapping_;
|
||||
};
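//
// Illustrative example of the mapping (added for clarity): Normalize("Foo/Bar")
// yields "foo_bar"; a later Uniquify("foo_bar") for a body node yields
// "foo_bar_0" because the first name is already taken; Renormalize("Foo/Bar")
// then returns "foo_bar".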
|
||||
|
||||
string NodeNameMapping::NormalizeHelper(string name) const {
|
||||
// Convert letters to lowercase and non-alphanumeric characters to '_'.
|
||||
if (name.empty()) name = "unknown";
|
||||
const int n = name.size();
|
||||
for (int i = 0; i < n; i++) {
|
||||
char c = name[i];
|
||||
if (isalnum(c)) {
|
||||
if (isupper(c)) {
|
||||
name[i] = tolower(c);
|
||||
}
|
||||
} else {
|
||||
name[i] = '_';
|
||||
}
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
string NodeNameMapping::UniquifyHelper(string name) {
|
||||
// If the name hasn't been used yet, use it as-is.
|
||||
if (used_names_.insert(name).second) return name;
|
||||
// Add a suffix to name to make it unique.
|
||||
for (int i = 0;; ++i) {
|
||||
const string candidate = strings::StrCat(name, "_", i);
|
||||
if (used_names_.insert(candidate).second) return candidate;
|
||||
}
|
||||
}
|
||||
|
||||
string NodeNameMapping::Normalize(const string& name) {
|
||||
const string normalized = UniquifyHelper(NormalizeHelper(name));
|
||||
name_mapping_[name] = normalized;
|
||||
return normalized;
|
||||
}
|
||||
|
||||
string NodeNameMapping::Uniquify(const string& name) {
|
||||
const string uniqued = UniquifyHelper(name);
|
||||
name_mapping_[name] = uniqued;
|
||||
return uniqued;
|
||||
}
|
||||
|
||||
string NodeNameMapping::Renormalize(const string& name) const {
|
||||
const auto iter = name_mapping_.find(name);
|
||||
if (iter == name_mapping_.end()) return string();
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Graph to FunctionDef conversion. This code is closely modeled on the Python
|
||||
// code in third_party/tensorflow/python/framework/function.py.
|
||||
|
||||
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
FunctionDef* fdef) {
|
||||
fdef->mutable_signature()->set_name(name);
|
||||
|
||||
std::unordered_map<string, string> tensor_renaming;
|
||||
std::unordered_map<string, string> return_values;
|
||||
NodeNameMapping node_names;
|
||||
|
||||
for (Node const* node : graph.nodes()) {
|
||||
if (!node->IsOp()) continue;
|
||||
|
||||
if (node->type_string() == kArgOp) {
|
||||
int index;
|
||||
DataType type;
|
||||
GetNodeAttr(node->def(), "T", &type);
|
||||
GetNodeAttr(node->def(), "index", &index);
|
||||
while (fdef->signature().input_arg_size() <= index) {
|
||||
fdef->mutable_signature()->add_input_arg();
|
||||
}
|
||||
OpDef::ArgDef* argdef =
|
||||
fdef->mutable_signature()->mutable_input_arg(index);
|
||||
argdef->set_type(type);
|
||||
const string normalized = node_names.Normalize(node->name());
|
||||
argdef->set_name(normalized);
|
||||
tensor_renaming[strings::StrCat(node->name(), ":0")] = normalized;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->type_string() == kRetValOp) {
|
||||
int index;
|
||||
DataType type;
|
||||
GetNodeAttr(node->def(), "T", &type);
|
||||
GetNodeAttr(node->def(), "index", &index);
|
||||
while (fdef->signature().output_arg_size() <= index) {
|
||||
fdef->mutable_signature()->add_output_arg();
|
||||
}
|
||||
OpDef::ArgDef* argdef =
|
||||
fdef->mutable_signature()->mutable_output_arg(index);
|
||||
argdef->set_type(type);
|
||||
const string normalized = node_names.Normalize(node->name());
|
||||
argdef->set_name(normalized);
|
||||
CHECK_EQ(node->in_edges().size(), 1);
|
||||
Edge const* edge = *node->in_edges().begin();
|
||||
return_values[normalized] =
|
||||
strings::StrCat(edge->src()->name(), ":", edge->src_output());
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeDef* node_def = fdef->add_node_def();
|
||||
node_def->CopyFrom(node->def());
|
||||
node_def->set_name(node_names.Uniquify(node->name()));
|
||||
node_def->clear_device();
|
||||
|
||||
// Reset input names based on graph rather than the NodeDef.
|
||||
node_def->clear_input();
|
||||
|
||||
// Edges, indexed by dst_input.
|
||||
std::vector<const Edge*> in_edges;
|
||||
std::vector<const Edge*> control_edges;
|
||||
for (Edge const* edge : node->in_edges()) {
|
||||
if (edge->src()->IsSource()) continue;
|
||||
|
||||
if (edge->IsControlEdge()) {
|
||||
control_edges.push_back(edge);
|
||||
} else {
|
||||
if (in_edges.size() <= edge->dst_input()) {
|
||||
in_edges.resize(edge->dst_input() + 1);
|
||||
}
|
||||
in_edges[edge->dst_input()] = edge;
|
||||
}
|
||||
}
|
||||
|
||||
// Add regular inputs
|
||||
for (int i = 0; i < in_edges.size(); ++i) {
|
||||
const Edge* edge = in_edges[i];
|
||||
if (edge == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Nonconsecutive input edges; missing "
|
||||
"input edge ",
|
||||
i, " for node ", node->name());
|
||||
}
|
||||
node_def->add_input(
|
||||
strings::StrCat(edge->src()->name(), ":", edge->src_output()));
|
||||
}
|
||||
|
||||
// Add control inputs
|
||||
for (const Edge* edge : control_edges) {
|
||||
node_def->add_input(strings::StrCat("^", edge->src()->name()));
|
||||
}
|
||||
|
||||
// Populate tensor_renaming.
|
||||
NameRangeMap output_ranges;
|
||||
TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
|
||||
&output_ranges));
|
||||
for (const auto& output : output_ranges) {
|
||||
for (int i = output.second.first; i < output.second.second; ++i) {
|
||||
const string tensor_name = strings::StrCat(
|
||||
node_def->name(), ":", output.first, ":", i - output.second.first);
|
||||
tensor_renaming[strings::StrCat(node->name(), ":", i)] = tensor_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect missing function inputs.
|
||||
for (int i = 0; i < fdef->signature().input_arg_size(); ++i) {
|
||||
const string& input_name = fdef->signature().input_arg(i).name();
|
||||
if (input_name.empty()) {
|
||||
return errors::InvalidArgument("Missing input ", i, " to function ",
|
||||
name);
|
||||
}
|
||||
}
|
||||
|
||||
// Remap input names. We do this as a second pass to allow the nodes to be in
|
||||
// any order.
|
||||
for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
|
||||
NodeDef* node_def = fdef->mutable_node_def(n_index);
|
||||
for (int i = 0; i < node_def->input_size(); ++i) {
|
||||
if (StringPiece(node_def->input(i)).starts_with("^")) {
|
||||
// Control input
|
||||
const string normalized =
|
||||
node_names.Renormalize(node_def->input(i).substr(1));
|
||||
if (normalized.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Could not remap control input ", i, ", '", node_def->input(i),
|
||||
"', of node '", node_def->name(), "' in function ", name);
|
||||
}
|
||||
*node_def->mutable_input(i) = strings::StrCat("^", normalized);
|
||||
} else {
|
||||
const auto iter = tensor_renaming.find(node_def->input(i));
|
||||
if (iter == tensor_renaming.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Could not remap input ", i, ", '", node_def->input(i),
|
||||
"', of node '", node_def->name(), "' in function ", name);
|
||||
}
|
||||
*node_def->mutable_input(i) = iter->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remap return values.
|
||||
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
|
||||
const string& ret_name = fdef->signature().output_arg(r).name();
|
||||
if (ret_name.empty()) {
|
||||
return errors::InvalidArgument("Missing output ", r, " to function ",
|
||||
name);
|
||||
}
|
||||
const string& return_value = return_values[ret_name];
|
||||
const auto iter = tensor_renaming.find(return_value);
|
||||
if (iter == tensor_renaming.end()) {
|
||||
return errors::InvalidArgument("Could not remap return value ", r, ", '",
|
||||
ret_name, "', of '", return_value,
|
||||
"' in function ", name);
|
||||
}
|
||||
(*fdef->mutable_ret())[ret_name] = iter->second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
33
tensorflow/compiler/jit/graph_to_functiondef.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts 'graph' to a FunctionDef 'fdef', with name 'name'.
|
||||
// Closely modeled on the Python code in
|
||||
// third_party/tensorflow/python/framework/function.py
|
||||
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
FunctionDef* fdef);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
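A minimal usage sketch, assuming `graph` is a populated Graph whose inputs and outputs appear as _Arg and _Retval nodes (as in the test file that follows); "my_fn" is an arbitrary example name.

FunctionDef fdef;
Status s = GraphToFunctionDef(graph, "my_fn", &fdef);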
|
87
tensorflow/compiler/jit/graph_to_functiondef_test.cc
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
|
||||
string* diff) {
|
||||
// TODO(phawkins) use a more sophisticated equality test.
|
||||
if (a.DebugString() != b.DebugString()) {
|
||||
if (diff) {
|
||||
*diff = strings::StrCat("Definition mismatch for function ",
|
||||
a.signature().name(), ":\n", a.DebugString(),
|
||||
"\n ---- vs. ----\n", b.DebugString());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
TEST(GraphToFunctionDefTest, Basics) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0);
|
||||
auto b = ops::_Arg(root.WithOpName("B"), DT_FLOAT, 1);
|
||||
auto c = ops::_Arg(root.WithOpName("C"), DT_FLOAT, 2);
|
||||
auto d = ops::Add(root.WithOpName("D"), a, b);
|
||||
auto e = ops::Add(root.WithOpName("b"), d, c);
|
||||
auto f = ops::Neg(root.WithOpName("h"), e);
|
||||
auto g =
|
||||
ops::AddN(root.WithOpName("G"), std::initializer_list<ops::Output>{e, f});
|
||||
auto h = ops::_Retval(root.WithOpName("H"), g, 0);
|
||||
|
||||
GraphDef graph_def;
|
||||
root.ToGraphDef(&graph_def);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphConstructorOptions options;
|
||||
TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get()));
|
||||
|
||||
FunctionDef fdef;
|
||||
TF_EXPECT_OK(GraphToFunctionDef(*graph, "test_fn", &fdef));
|
||||
|
||||
FunctionDef fdef_expected = FunctionDefHelper::Create(
|
||||
"test_fn", // function name
|
||||
{"a: float", "b: float", "c: float"}, // inputs
|
||||
{"h_0: float"}, // outputs
|
||||
{}, // attrs
|
||||
{
|
||||
// nodes in the function body
|
||||
{{"D"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
|
||||
{{"b_0"}, "Add", {"D:z:0", "c"}, {{"T", DT_FLOAT}}},
|
||||
{{"h"}, "Neg", {"b_0:z:0"}, {{"T", DT_FLOAT}}},
|
||||
{{"G"}, "AddN", {"b_0:z:0", "h:y:0"}, {{"N", 2}, {"T", DT_FLOAT}}},
|
||||
},
|
||||
{{"h_0", "G:sum:0"}}); // return values
|
||||
|
||||
string diff;
|
||||
bool fdefs_equal = EqualFunctionDef(fdef_expected, fdef, &diff);
|
||||
EXPECT_TRUE(fdefs_equal) << diff;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
41
tensorflow/compiler/jit/graphcycles/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graphcycles",
|
||||
srcs = ["graphcycles.cc"],
|
||||
hdrs = ["graphcycles.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "graphcycles_test",
|
||||
srcs = ["graphcycles_test.cc"],
|
||||
deps = [
|
||||
":graphcycles",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
391
tensorflow/compiler/jit/graphcycles/graphcycles.cc
Normal file
@ -0,0 +1,391 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// GraphCycles provides incremental cycle detection on a dynamic
|
||||
// graph using the following algorithm:
|
||||
//
|
||||
// A dynamic topological sort algorithm for directed acyclic graphs
|
||||
// David J. Pearce, Paul H. J. Kelly
|
||||
// Journal of Experimental Algorithmics (JEA) JEA Homepage archive
|
||||
// Volume 11, 2006, Article No. 1.7
|
||||
//
|
||||
// Brief summary of the algorithm:
|
||||
//
|
||||
// (1) Maintain a rank for each node that is consistent
|
||||
// with the topological sort of the graph. I.e., path from x to y
|
||||
// implies rank[x] < rank[y].
|
||||
// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
|
||||
// (3) Otherwise: adjust ranks in the neighborhood of x and y.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::unordered_set<int32> NodeSet;
|
||||
template <typename T>
|
||||
struct VecStruct {
|
||||
typedef gtl::InlinedVector<T, 4> type;
|
||||
};
|
||||
template <typename T>
|
||||
using Vec = typename VecStruct<T>::type;
|
||||
|
||||
struct Node {
|
||||
Node() : in(4), out(4) {} // Small hashtables for in/out edges
|
||||
|
||||
int32 rank; // rank number assigned by Pearce-Kelly algorithm
|
||||
bool visited; // Temporary marker used by depth-first-search
|
||||
void* data; // User-supplied data
|
||||
NodeSet in;   // Set of immediate predecessor nodes in graph
|
||||
NodeSet out;  // Set of immediate successor nodes in graph
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
struct GraphCycles::Rep {
|
||||
Vec<Node*> nodes_;
|
||||
Vec<int32> free_nodes_; // Indices for unused entries in nodes_
|
||||
|
||||
// Temporary state.
|
||||
Vec<int32> deltaf_; // Results of forward DFS
|
||||
Vec<int32> deltab_; // Results of backward DFS
|
||||
Vec<int32> list_; // All nodes to reprocess
|
||||
Vec<int32> merged_; // Rank values to assign to list_ entries
|
||||
Vec<int32> stack_; // Emulates recursion stack when doing depth first search
|
||||
};
|
||||
|
||||
GraphCycles::GraphCycles() : rep_(new Rep) {}
|
||||
|
||||
GraphCycles::~GraphCycles() {
|
||||
for (int i = 0; i < rep_->nodes_.size(); i++) {
|
||||
delete rep_->nodes_[i];
|
||||
}
|
||||
delete rep_;
|
||||
}
|
||||
|
||||
bool GraphCycles::CheckInvariants() const {
|
||||
Rep* r = rep_;
|
||||
NodeSet ranks; // Set of ranks seen so far.
|
||||
for (int32 x = 0; x < r->nodes_.size(); x++) {
|
||||
Node* nx = r->nodes_[x];
|
||||
if (nx->visited) {
|
||||
LOG(FATAL) << "Did not clear visited marker on node " << x;
|
||||
}
|
||||
if (!ranks.insert(nx->rank).second) {
|
||||
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
|
||||
}
|
||||
for (auto y : nx->out) {
|
||||
Node* ny = r->nodes_[y];
|
||||
if (nx->rank >= ny->rank) {
|
||||
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
|
||||
<< nx->rank << "->" << ny->rank;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int32 GraphCycles::NewNode() {
|
||||
if (rep_->free_nodes_.empty()) {
|
||||
Node* n = new Node;
|
||||
n->visited = false;
|
||||
n->data = NULL;
|
||||
n->rank = rep_->nodes_.size();
|
||||
rep_->nodes_.push_back(n);
|
||||
return n->rank;
|
||||
} else {
|
||||
// Preserve preceding rank since the set of ranks in use must be
|
||||
// a permutation of [0,rep_->nodes_.size()-1].
|
||||
int32 r = rep_->free_nodes_.back();
|
||||
rep_->nodes_[r]->data = NULL;
|
||||
rep_->free_nodes_.pop_back();
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
void GraphCycles::RemoveNode(int32 node) {
|
||||
Node* x = rep_->nodes_[node];
|
||||
for (auto y : x->out) {
|
||||
rep_->nodes_[y]->in.erase(node);
|
||||
}
|
||||
for (auto y : x->in) {
|
||||
rep_->nodes_[y]->out.erase(node);
|
||||
}
|
||||
x->in.clear();
|
||||
x->out.clear();
|
||||
rep_->free_nodes_.push_back(node);
|
||||
}
|
||||
|
||||
void* GraphCycles::GetNodeData(int32 node) const {
|
||||
return rep_->nodes_[node]->data;
|
||||
}
|
||||
|
||||
void GraphCycles::SetNodeData(int32 node, void* data) {
|
||||
rep_->nodes_[node]->data = data;
|
||||
}
|
||||
|
||||
bool GraphCycles::HasEdge(int32 x, int32 y) const {
|
||||
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
|
||||
}
|
||||
|
||||
void GraphCycles::RemoveEdge(int32 x, int32 y) {
|
||||
rep_->nodes_[x]->out.erase(y);
|
||||
rep_->nodes_[y]->in.erase(x);
|
||||
// No need to update the rank assignment since a previous valid
|
||||
// rank assignment remains valid after an edge deletion.
|
||||
}
|
||||
|
||||
static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound);
|
||||
static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound);
|
||||
static void Reorder(GraphCycles::Rep* r);
|
||||
static void Sort(const Vec<Node*>&, Vec<int32>* delta);
|
||||
static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst);
|
||||
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes);
|
||||
|
||||
bool GraphCycles::InsertEdge(int32 x, int32 y) {
|
||||
if (x == y) return false;
|
||||
Rep* r = rep_;
|
||||
Node* nx = r->nodes_[x];
|
||||
if (!nx->out.insert(y).second) {
|
||||
// Edge already exists.
|
||||
return true;
|
||||
}
|
||||
|
||||
Node* ny = r->nodes_[y];
|
||||
ny->in.insert(x);
|
||||
|
||||
if (nx->rank <= ny->rank) {
|
||||
// New edge is consistent with existing rank assignment.
|
||||
return true;
|
||||
}
|
||||
|
||||
// Current rank assignments are incompatible with the new edge. Recompute.
|
||||
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
|
||||
if (!ForwardDFS(r, y, nx->rank)) {
|
||||
// Found a cycle. Undo the insertion and tell caller.
|
||||
nx->out.erase(y);
|
||||
ny->in.erase(x);
|
||||
// Since we do not call Reorder() on this path, clear any visited
|
||||
// markers left by ForwardDFS.
|
||||
ClearVisitedBits(r, r->deltaf_);
|
||||
return false;
|
||||
}
|
||||
BackwardDFS(r, x, ny->rank);
|
||||
Reorder(r);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
|
||||
// Avoid recursion since stack space might be limited.
|
||||
// We instead keep a stack of nodes to visit.
|
||||
r->deltaf_.clear();
|
||||
r->stack_.clear();
|
||||
r->stack_.push_back(n);
|
||||
while (!r->stack_.empty()) {
|
||||
n = r->stack_.back();
|
||||
r->stack_.pop_back();
|
||||
Node* nn = r->nodes_[n];
|
||||
if (nn->visited) continue;
|
||||
|
||||
nn->visited = true;
|
||||
r->deltaf_.push_back(n);
|
||||
|
||||
for (auto w : nn->out) {
|
||||
Node* nw = r->nodes_[w];
|
||||
if (nw->rank == upper_bound) {
|
||||
return false; // Cycle
|
||||
}
|
||||
if (!nw->visited && nw->rank < upper_bound) {
|
||||
r->stack_.push_back(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
|
||||
r->deltab_.clear();
|
||||
r->stack_.clear();
|
||||
r->stack_.push_back(n);
|
||||
while (!r->stack_.empty()) {
|
||||
n = r->stack_.back();
|
||||
r->stack_.pop_back();
|
||||
Node* nn = r->nodes_[n];
|
||||
if (nn->visited) continue;
|
||||
|
||||
nn->visited = true;
|
||||
r->deltab_.push_back(n);
|
||||
|
||||
for (auto w : nn->in) {
|
||||
Node* nw = r->nodes_[w];
|
||||
if (!nw->visited && lower_bound < nw->rank) {
|
||||
r->stack_.push_back(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void Reorder(GraphCycles::Rep* r) {
|
||||
Sort(r->nodes_, &r->deltab_);
|
||||
Sort(r->nodes_, &r->deltaf_);
|
||||
|
||||
// Adds contents of delta lists to list_ (backwards deltas first).
|
||||
r->list_.clear();
|
||||
MoveToList(r, &r->deltab_, &r->list_);
|
||||
MoveToList(r, &r->deltaf_, &r->list_);
|
||||
|
||||
// Produce sorted list of all ranks that will be reassigned.
|
||||
r->merged_.resize(r->deltab_.size() + r->deltaf_.size());
|
||||
std::merge(r->deltab_.begin(), r->deltab_.end(), r->deltaf_.begin(),
|
||||
r->deltaf_.end(), r->merged_.begin());
|
||||
|
||||
// Assign the ranks in order to the collected list.
|
||||
for (int32 i = 0; i < r->list_.size(); i++) {
|
||||
r->nodes_[r->list_[i]]->rank = r->merged_[i];
|
||||
}
|
||||
}
|
||||
|
||||
static void Sort(const Vec<Node*>& nodes, Vec<int32>* delta) {
|
||||
struct ByRank {
|
||||
const Vec<Node*>* nodes;
|
||||
bool operator()(int32 a, int32 b) const {
|
||||
return (*nodes)[a]->rank < (*nodes)[b]->rank;
|
||||
}
|
||||
};
|
||||
ByRank cmp;
|
||||
cmp.nodes = &nodes;
|
||||
std::sort(delta->begin(), delta->end(), cmp);
|
||||
}
|
||||
|
||||
static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) {
|
||||
for (int32 i = 0; i < src->size(); i++) {
|
||||
int32 w = (*src)[i];
|
||||
(*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank
|
||||
r->nodes_[w]->visited = false; // Prepare for future DFS calls
|
||||
dst->push_back(w);
|
||||
}
|
||||
}
|
||||
|
||||
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes) {
|
||||
for (int32 i = 0; i < nodes.size(); i++) {
|
||||
r->nodes_[nodes[i]]->visited = false;
|
||||
}
|
||||
}
|
||||
|
||||
int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
|
||||
int32 path[]) const {
|
||||
// Forward depth first search starting at x until we hit y.
|
||||
// As we descend into a node, we push it onto the path.
|
||||
// As we leave a node, we remove it from the path.
|
||||
int path_len = 0;
|
||||
|
||||
Rep* r = rep_;
|
||||
NodeSet seen;
|
||||
r->stack_.clear();
|
||||
r->stack_.push_back(x);
|
||||
while (!r->stack_.empty()) {
|
||||
int32 n = r->stack_.back();
|
||||
r->stack_.pop_back();
|
||||
if (n < 0) {
|
||||
// Marker to indicate that we are leaving a node
|
||||
path_len--;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (path_len < max_path_len) {
|
||||
path[path_len] = n;
|
||||
}
|
||||
path_len++;
|
||||
r->stack_.push_back(-1); // Will remove tentative path entry
|
||||
|
||||
if (n == y) {
|
||||
return path_len;
|
||||
}
|
||||
|
||||
for (auto w : r->nodes_[n]->out) {
|
||||
if (seen.insert(w).second) {
|
||||
r->stack_.push_back(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool GraphCycles::IsReachable(int32 x, int32 y) const {
|
||||
return FindPath(x, y, 0, NULL) > 0;
|
||||
}
|
||||
|
||||
bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {
|
||||
if (x == y) return true;
|
||||
Rep* r = rep_;
|
||||
Node* nx = r->nodes_[x];
|
||||
Node* ny = r->nodes_[y];
|
||||
|
||||
if (nx->rank >= ny->rank) {
|
||||
// x cannot reach y since it is after it in the topological ordering
|
||||
return false;
|
||||
}
|
||||
|
||||
// See if x can reach y using a DFS search that is limited to y's rank
|
||||
bool reachable = !ForwardDFS(r, x, ny->rank);
|
||||
|
||||
// Clear any visited markers left by ForwardDFS.
|
||||
ClearVisitedBits(r, r->deltaf_);
|
||||
return reachable;
|
||||
}
|
||||
|
||||
bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
CHECK(HasEdge(a, b));
|
||||
RemoveEdge(a, b);
|
||||
|
||||
if (IsReachableNonConst(a, b)) {
|
||||
// Restore the graph to its original state.
|
||||
InsertEdge(a, b);
|
||||
return false;
|
||||
}
|
||||
|
||||
Node* nb = rep_->nodes_[b];
|
||||
std::unordered_set<int32> out = std::move(nb->out);
|
||||
std::unordered_set<int32> in = std::move(nb->in);
|
||||
for (auto y : out) {
|
||||
rep_->nodes_[y]->in.erase(b);
|
||||
}
|
||||
for (auto y : in) {
|
||||
rep_->nodes_[y]->out.erase(b);
|
||||
}
|
||||
rep_->free_nodes_.push_back(b);
|
||||
|
||||
for (auto y : out) {
|
||||
InsertEdge(a, y);
|
||||
}
|
||||
for (auto y : in) {
|
||||
InsertEdge(y, a);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<int32> GraphCycles::Successors(int32 node) {
|
||||
return rep_->nodes_[node]->out;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
128
tensorflow/compiler/jit/graphcycles/graphcycles.h
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
|
||||
// GraphCycles detects the introduction of a cycle into a directed
|
||||
// graph that is being built up incrementally.
|
||||
//
|
||||
// Nodes are identified by small integers. It is not possible to
|
||||
// record multiple edges with the same (source, destination) pair;
|
||||
// requests to add an edge where one already exists are silently
|
||||
// ignored.
|
||||
//
|
||||
// It is also not possible to introduce a cycle; an attempt to insert
|
||||
// an edge that would introduce a cycle fails and returns false.
|
||||
//
|
||||
// GraphCycles uses no internal locking; calls into it should be
|
||||
// serialized externally.
|
||||
|
||||
// Performance considerations:
|
||||
// Works well on sparse graphs, poorly on dense graphs.
|
||||
// Extra information is maintained incrementally to detect cycles quickly.
|
||||
// InsertEdge() is very fast when the edge already exists, and reasonably fast
|
||||
// otherwise.
|
||||
// FindPath() is linear in the size of the graph.
|
||||
// The current implementation uses O(|V|+|E|) space.
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// NOTE!!!
|
||||
// For now a copy of this is forked to net/plaque. If you
|
||||
// find a bug or add a feature, please inform the owners of the
|
||||
// net/plaque copy in case it should be integrated.
|
||||
// NOTE!!!
|
||||
class GraphCycles {
|
||||
public:
|
||||
GraphCycles();
|
||||
~GraphCycles();
|
||||
|
||||
// Allocate an unused node id and return it.
|
||||
// The new node has a null pointer for its node data.
|
||||
// All node identifiers passed to other routines in this interface
|
||||
// must have been allocated by NewNode() and not yet deallocated
|
||||
// by RemoveNode().
|
||||
int32 NewNode();
|
||||
|
||||
// Remove "node" from the graph, deleting all edges to and from it.
|
||||
// After this call the identifier "node" may no longer be used
|
||||
// as an argument to any routine until it has been reallocated with
|
||||
// NewNode().
|
||||
void RemoveNode(int32 node);
|
||||
|
||||
// Attempt to insert an edge from source_node to dest_node. If the
|
||||
// edge would introduce a cycle, return false without making any
|
||||
// changes. Otherwise add the edge and return true.
|
||||
bool InsertEdge(int32 source_node, int32 dest_node);
|
||||
|
||||
// Remove any edge that exists from source_node to dest_node.
|
||||
void RemoveEdge(int32 source_node, int32 dest_node);
|
||||
|
||||
// Return whether there is an edge directly from source_node to dest_node.
|
||||
bool HasEdge(int32 source_node, int32 dest_node) const;
|
||||
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
|
||||
// removed from the graph, and edges to/from 'b' are replaced with edges
|
||||
// to/from 'a'. If contracting the edge would create a cycle, does nothing
|
||||
// and returns false.
|
||||
bool ContractEdge(int32 a, int32 b);
|
||||
|
||||
// Return whether dest_node is reachable from source_node
|
||||
// by following edges.
|
||||
bool IsReachable(int32 source_node, int32 dest_node) const;
|
||||
|
||||
// A faster non-thread-safe version of IsReachable.
|
||||
bool IsReachableNonConst(int32 source_node, int32 dest_node);
|
||||
|
||||
// Return or set the node data for a node. This data is unused
|
||||
// by the implementation.
|
||||
void *GetNodeData(int32 node) const;
|
||||
void SetNodeData(int32 node, void *data);
|
||||
|
||||
// Find a path from "source" to "dest". If such a path exists, place the
|
||||
// node IDs of the nodes on the path in the array path[], and return the
|
||||
// number of nodes on the path. If the path is longer than max_path_len
|
||||
// nodes, only the first max_path_len nodes are placed in path[]. The client
|
||||
// should compare the return value with max_path_len to see when this
|
||||
// occurs. If no path exists, return 0. Any valid path stored in path[]
|
||||
// will start with "source" and end with "dest". There is no guarantee that
|
||||
// the path is the shortest, but no node will appear twice in the path,
|
||||
// except the source and destination node if they are identical; therefore,
|
||||
// the return value is at most one greater than the number of nodes in the
|
||||
// graph.
|
||||
int FindPath(int32 source, int32 dest, int max_path_len, int32 path[]) const;
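//
// Example (illustrative), with a GraphCycles g and node ids src and dest:
//   int32 path[8];
//   int len = g.FindPath(src, dest, 8, path);
// len == 0 means dest is unreachable; len > 8 means only the first 8 nodes
// of the path were stored in path[].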
|
||||
|
||||
// Check internal invariants. Crashes on failure, returns true on success.
|
||||
// Expensive: should only be called from graphcycles_test.cc.
|
||||
bool CheckInvariants() const;
|
||||
|
||||
std::unordered_set<int32> Successors(int32 node);
|
||||
|
||||
// ----------------------------------------------------
|
||||
struct Rep;
|
||||
|
||||
private:
|
||||
Rep *rep_; // opaque representation
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GraphCycles);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
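A small sketch of the InsertEdge contract documented above (node ids come from NewNode(); CHECK is the macro from tensorflow/core/platform/logging.h):

tensorflow::GraphCycles g;
tensorflow::int32 a = g.NewNode();
tensorflow::int32 b = g.NewNode();
CHECK(g.InsertEdge(a, b));   // a -> b: no cycle, edge is added
CHECK(!g.InsertEdge(b, a));  // b -> a would close a cycle, so it is refused
CHECK(g.IsReachable(a, b));  // reachability follows the surviving edge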
|
515
tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
Normal file
@ -0,0 +1,515 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A test for the GraphCycles interface.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::int32;
|
||||
using tensorflow::string;
|
||||
|
||||
// We emulate a GraphCycles object with a node vector and an edge vector.
|
||||
// We then compare the two implementations.
|
||||
|
||||
typedef std::vector<int> Nodes;
|
||||
struct Edge {
|
||||
int from;
|
||||
int to;
|
||||
};
|
||||
typedef std::vector<Edge> Edges;
|
||||
|
||||
// Return whether "to" is reachable from "from".
|
||||
static bool IsReachable(Edges *edges, int from, int to,
|
||||
std::unordered_set<int> *seen) {
|
||||
seen->insert(from); // we are investigating "from"; don't do it again
|
||||
if (from == to) return true;
|
||||
for (int i = 0; i != edges->size(); i++) {
|
||||
Edge *edge = &(*edges)[i];
|
||||
if (edge->from == from) {
|
||||
if (edge->to == to) { // success via edge directly
|
||||
return true;
|
||||
} else if (seen->find(edge->to) == seen->end() && // success via edge
|
||||
IsReachable(edges, edge->to, to, seen)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void PrintNodes(Nodes *nodes) {
|
||||
LOG(INFO) << "NODES (" << nodes->size() << ")";
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
LOG(INFO) << (*nodes)[i];
|
||||
}
|
||||
}
|
||||
|
||||
static void PrintEdges(Edges *edges) {
|
||||
LOG(INFO) << "EDGES (" << edges->size() << ")";
|
||||
for (int i = 0; i != edges->size(); i++) {
|
||||
int a = (*edges)[i].from;
|
||||
int b = (*edges)[i].to;
|
||||
LOG(INFO) << a << " " << b;
|
||||
}
|
||||
LOG(INFO) << "---";
|
||||
}
|
||||
|
||||
static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) {
|
||||
LOG(INFO) << "GC EDGES";
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
for (int j = 0; j != nodes->size(); j++) {
|
||||
int a = (*nodes)[i];
|
||||
int b = (*nodes)[j];
|
||||
if (gc->HasEdge(a, b)) {
|
||||
LOG(INFO) << a << " " << b;
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "---";
|
||||
}
|
||||
|
||||
static void PrintTransitiveClosure(Nodes *nodes, Edges *edges,
|
||||
tensorflow::GraphCycles *gc) {
|
||||
LOG(INFO) << "Transitive closure";
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
for (int j = 0; j != nodes->size(); j++) {
|
||||
int a = (*nodes)[i];
|
||||
int b = (*nodes)[j];
|
||||
std::unordered_set<int> seen;
|
||||
if (IsReachable(edges, a, b, &seen)) {
|
||||
LOG(INFO) << a << " " << b;
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "---";
|
||||
}
|
||||
|
||||
static void PrintGCTransitiveClosure(Nodes *nodes,
|
||||
tensorflow::GraphCycles *gc) {
|
||||
LOG(INFO) << "GC Transitive closure";
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
for (int j = 0; j != nodes->size(); j++) {
|
||||
int a = (*nodes)[i];
|
||||
int b = (*nodes)[j];
|
||||
if (gc->IsReachable(a, b)) {
|
||||
LOG(INFO) << a << " " << b;
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "---";
|
||||
}
|
||||
|
||||
static void CheckTransitiveClosure(Nodes *nodes, Edges *edges,
|
||||
tensorflow::GraphCycles *gc) {
|
||||
std::unordered_set<int> seen;
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
for (int j = 0; j != nodes->size(); j++) {
|
||||
seen.clear();
|
||||
int a = (*nodes)[i];
|
||||
int b = (*nodes)[j];
|
||||
bool gc_reachable = gc->IsReachable(a, b);
|
||||
CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b));
|
||||
bool reachable = IsReachable(edges, a, b, &seen);
|
||||
if (gc_reachable != reachable) {
|
||||
PrintEdges(edges);
|
||||
PrintGCEdges(nodes, gc);
|
||||
PrintTransitiveClosure(nodes, edges, gc);
|
||||
PrintGCTransitiveClosure(nodes, gc);
|
||||
LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable "
|
||||
<< reachable << " a " << a << " b " << b;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void CheckEdges(Nodes *nodes, Edges *edges,
|
||||
tensorflow::GraphCycles *gc) {
|
||||
int count = 0;
|
||||
for (int i = 0; i != edges->size(); i++) {
|
||||
int a = (*edges)[i].from;
|
||||
int b = (*edges)[i].to;
|
||||
if (!gc->HasEdge(a, b)) {
|
||||
PrintEdges(edges);
|
||||
PrintGCEdges(nodes, gc);
|
||||
LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")";
|
||||
}
|
||||
}
|
||||
for (int i = 0; i != nodes->size(); i++) {
|
||||
for (int j = 0; j != nodes->size(); j++) {
|
||||
int a = (*nodes)[i];
|
||||
int b = (*nodes)[j];
|
||||
if (gc->HasEdge(a, b)) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (count != edges->size()) {
|
||||
PrintEdges(edges);
|
||||
PrintGCEdges(nodes, gc);
|
||||
LOG(FATAL) << "edges->size() " << edges->size() << " count " << count;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the index of a randomly chosen node in *nodes.
|
||||
// Requires *nodes be non-empty.
|
||||
static int RandomNode(std::mt19937 *rnd, Nodes *nodes) {
|
||||
std::uniform_int_distribution<int> distribution(0, nodes->size() - 1);
|
||||
return distribution(*rnd);
|
||||
}
|
||||
|
||||
// Returns the index of a randomly chosen edge in *edges.
|
||||
// Requires *edges be non-empty.
|
||||
static int RandomEdge(std::mt19937 *rnd, Edges *edges) {
|
||||
std::uniform_int_distribution<int> distribution(0, edges->size() - 1);
|
||||
return distribution(*rnd);
|
||||
}
|
||||
|
||||
// Returns the index of edge (from, to) in *edges or -1 if it is not in *edges.
|
||||
static int EdgeIndex(Edges *edges, int from, int to) {
|
||||
int i = 0;
|
||||
while (i != edges->size() &&
|
||||
((*edges)[i].from != from || (*edges)[i].to != to)) {
|
||||
i++;
|
||||
}
|
||||
return i == edges->size() ? -1 : i;
|
||||
}
|
||||
|
||||
TEST(GraphCycles, RandomizedTest) {
|
||||
Nodes nodes;
|
||||
Edges edges; // from, to
|
||||
tensorflow::GraphCycles graph_cycles;
|
||||
static const int kMaxNodes = 7; // use <= 7 nodes to keep test short
|
||||
static const int kDataOffset = 17; // an offset to the node-specific data
|
||||
int n = 100000;
|
||||
int op = 0;
|
||||
std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1);
|
||||
|
||||
for (int iter = 0; iter != n; iter++) {
|
||||
if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n;
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
LOG(INFO) << "===============";
|
||||
LOG(INFO) << "last op " << op;
|
||||
PrintNodes(&nodes);
|
||||
PrintEdges(&edges);
|
||||
PrintGCEdges(&nodes, &graph_cycles);
|
||||
}
|
||||
for (int i = 0; i != nodes.size(); i++) {
|
||||
ASSERT_EQ(reinterpret_cast<intptr_t>(graph_cycles.GetNodeData(i)),
|
||||
i + kDataOffset)
|
||||
<< " node " << i;
|
||||
}
|
||||
CheckEdges(&nodes, &edges, &graph_cycles);
|
||||
CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
|
||||
std::uniform_int_distribution<int> distribution(0, 5);
|
||||
op = distribution(rnd);
|
||||
switch (op) {
|
||||
case 0: // Add a node
|
||||
if (nodes.size() < kMaxNodes) {
|
||||
int new_node = graph_cycles.NewNode();
|
||||
ASSERT_NE(-1, new_node);
|
||||
VLOG(1) << "adding node " << new_node;
|
||||
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
|
||||
graph_cycles.SetNodeData(
|
||||
new_node, reinterpret_cast<void *>(
|
||||
static_cast<intptr_t>(new_node + kDataOffset)));
|
||||
ASSERT_GE(new_node, 0);
|
||||
for (int i = 0; i != nodes.size(); i++) {
|
||||
ASSERT_NE(nodes[i], new_node);
|
||||
}
|
||||
nodes.push_back(new_node);
|
||||
}
|
||||
break;
|
||||
|
||||
case 1: // Remove a node
|
||||
if (nodes.size() > 0) {
|
||||
int node_index = RandomNode(&rnd, &nodes);
|
||||
int node = nodes[node_index];
|
||||
nodes[node_index] = nodes.back();
|
||||
nodes.pop_back();
|
||||
VLOG(1) << "removing node " << node;
|
||||
graph_cycles.RemoveNode(node);
|
||||
int i = 0;
|
||||
while (i != edges.size()) {
|
||||
if (edges[i].from == node || edges[i].to == node) {
|
||||
edges[i] = edges.back();
|
||||
edges.pop_back();
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case 2: // Add an edge
|
||||
if (nodes.size() > 0) {
|
||||
int from = RandomNode(&rnd, &nodes);
|
||||
int to = RandomNode(&rnd, &nodes);
|
||||
if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
|
||||
if (graph_cycles.InsertEdge(nodes[from], nodes[to])) {
|
||||
Edge new_edge;
|
||||
new_edge.from = nodes[from];
|
||||
new_edge.to = nodes[to];
|
||||
edges.push_back(new_edge);
|
||||
} else {
|
||||
std::unordered_set<int> seen;
|
||||
ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen))
|
||||
<< "Edge " << nodes[to] << "->" << nodes[from];
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case 3: // Remove an edge
|
||||
if (edges.size() > 0) {
|
||||
int i = RandomEdge(&rnd, &edges);
|
||||
int from = edges[i].from;
|
||||
int to = edges[i].to;
|
||||
ASSERT_EQ(i, EdgeIndex(&edges, from, to));
|
||||
edges[i] = edges.back();
|
||||
edges.pop_back();
|
||||
ASSERT_EQ(-1, EdgeIndex(&edges, from, to));
|
||||
VLOG(1) << "removing edge " << from << " " << to;
|
||||
graph_cycles.RemoveEdge(from, to);
|
||||
}
|
||||
break;
|
||||
|
||||
case 4: // Check a path
|
||||
if (nodes.size() > 0) {
|
||||
int from = RandomNode(&rnd, &nodes);
|
||||
int to = RandomNode(&rnd, &nodes);
|
||||
int32 path[2 * kMaxNodes];
|
||||
int path_len = graph_cycles.FindPath(nodes[from], nodes[to],
|
||||
2 * kMaxNodes, path);
|
||||
std::unordered_set<int> seen;
|
||||
bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen);
|
||||
bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]);
|
||||
ASSERT_EQ(gc_reachable,
|
||||
graph_cycles.IsReachableNonConst(nodes[from], nodes[to]));
|
||||
ASSERT_EQ(path_len != 0, reachable);
|
||||
ASSERT_EQ(path_len != 0, gc_reachable);
|
||||
// In the following line, we add one because a node can appear
|
||||
// twice, if the path is from that node to itself, perhaps via
|
||||
// every other node.
|
||||
ASSERT_LE(path_len, kMaxNodes + 1);
|
||||
if (path_len != 0) {
|
||||
ASSERT_EQ(nodes[from], path[0]);
|
||||
ASSERT_EQ(nodes[to], path[path_len - 1]);
|
||||
for (int i = 1; i < path_len; i++) {
|
||||
ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i]));
|
||||
ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case 5: // Check invariants
|
||||
CHECK(graph_cycles.CheckInvariants());
|
||||
break;
|
||||
|
||||
default:
|
||||
LOG(FATAL);
|
||||
}
|
||||
|
||||
// Very rarely, test graph expansion by adding then removing many nodes.
|
||||
std::bernoulli_distribution rarely(1.0 / 1024.0);
|
||||
if (rarely(rnd)) {
|
||||
VLOG(3) << "Graph expansion";
|
||||
CheckEdges(&nodes, &edges, &graph_cycles);
|
||||
CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
|
||||
for (int i = 0; i != 256; i++) {
|
||||
int new_node = graph_cycles.NewNode();
|
||||
ASSERT_NE(-1, new_node);
|
||||
VLOG(1) << "adding node " << new_node;
|
||||
ASSERT_GE(new_node, 0);
|
||||
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
|
||||
graph_cycles.SetNodeData(
|
||||
new_node, reinterpret_cast<void *>(
|
||||
static_cast<intptr_t>(new_node + kDataOffset)));
|
||||
for (int j = 0; j != nodes.size(); j++) {
|
||||
ASSERT_NE(nodes[j], new_node);
|
||||
}
|
||||
nodes.push_back(new_node);
|
||||
}
|
||||
for (int i = 0; i != 256; i++) {
|
||||
ASSERT_GT(nodes.size(), 0);
|
||||
int node_index = RandomNode(&rnd, &nodes);
|
||||
int node = nodes[node_index];
|
||||
nodes[node_index] = nodes.back();
|
||||
nodes.pop_back();
|
||||
VLOG(1) << "removing node " << node;
|
||||
graph_cycles.RemoveNode(node);
|
||||
int j = 0;
|
||||
while (j != edges.size()) {
|
||||
if (edges[j].from == node || edges[j].to == node) {
|
||||
edges[j] = edges.back();
|
||||
edges.pop_back();
|
||||
} else {
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(graph_cycles.CheckInvariants());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GraphCyclesTest : public ::testing::Test {
|
||||
public:
|
||||
tensorflow::GraphCycles g_;
|
||||
|
||||
// Test relies on the ith NewNode() call returning node number i
|
||||
GraphCyclesTest() {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
CHECK_EQ(i, g_.NewNode());
|
||||
}
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
|
||||
|
||||
void AddMultiples() {
|
||||
// For every node x > 0: add edge to 2*x, 3*x
|
||||
for (int x = 1; x < 25; x++) {
|
||||
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
|
||||
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
|
||||
}
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
string Path(int x, int y) {
|
||||
static const int kPathSize = 5;
|
||||
int32 path[kPathSize];
|
||||
int np = g_.FindPath(x, y, kPathSize, path);
|
||||
string result;
|
||||
for (int i = 0; i < np; i++) {
|
||||
if (i >= kPathSize) {
|
||||
result += " ...";
|
||||
break;
|
||||
}
|
||||
if (!result.empty()) result.push_back(' ');
|
||||
char buf[20];
|
||||
snprintf(buf, sizeof(buf), "%d", path[i]);
|
||||
result += buf;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GraphCyclesTest, NoCycle) {
|
||||
AddMultiples();
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, SimpleCycle) {
|
||||
AddMultiples();
|
||||
EXPECT_FALSE(AddEdge(8, 4));
|
||||
EXPECT_EQ("4 8", Path(4, 8));
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, IndirectCycle) {
|
||||
AddMultiples();
|
||||
EXPECT_TRUE(AddEdge(16, 9));
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_FALSE(AddEdge(9, 2));
|
||||
EXPECT_EQ("2 4 8 16 9", Path(2, 9));
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, LongPath) {
|
||||
ASSERT_TRUE(AddEdge(2, 4));
|
||||
ASSERT_TRUE(AddEdge(4, 6));
|
||||
ASSERT_TRUE(AddEdge(6, 8));
|
||||
ASSERT_TRUE(AddEdge(8, 10));
|
||||
ASSERT_TRUE(AddEdge(10, 12));
|
||||
ASSERT_FALSE(AddEdge(12, 2));
|
||||
EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12));
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, RemoveNode) {
|
||||
ASSERT_TRUE(AddEdge(1, 2));
|
||||
ASSERT_TRUE(AddEdge(2, 3));
|
||||
ASSERT_TRUE(AddEdge(3, 4));
|
||||
ASSERT_TRUE(AddEdge(4, 5));
|
||||
g_.RemoveNode(3);
|
||||
ASSERT_TRUE(AddEdge(5, 1));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, ManyEdges) {
|
||||
const int N = 50;
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 1; j < N; j++) {
|
||||
ASSERT_TRUE(AddEdge(i, i + j));
|
||||
}
|
||||
}
|
||||
CHECK(g_.CheckInvariants());
|
||||
ASSERT_TRUE(AddEdge(2 * N - 1, 0));
|
||||
CHECK(g_.CheckInvariants());
|
||||
ASSERT_FALSE(AddEdge(10, 9));
|
||||
CHECK(g_.CheckInvariants());
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, ContractEdge) {
|
||||
ASSERT_TRUE(AddEdge(1, 2));
|
||||
ASSERT_TRUE(AddEdge(1, 3));
|
||||
ASSERT_TRUE(AddEdge(2, 3));
|
||||
ASSERT_TRUE(AddEdge(2, 4));
|
||||
ASSERT_TRUE(AddEdge(3, 4));
|
||||
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3));
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 2));
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(3, 4));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 3));
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
}
|
||||
|
||||
static void BM_StressTest(int iters, int num_nodes) {
|
||||
while (iters > 0) {
|
||||
tensorflow::GraphCycles g;
|
||||
int32 *nodes = new int32[num_nodes];
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
nodes[i] = g.NewNode();
|
||||
}
|
||||
for (int i = 0; i < num_nodes && iters > 0; i++, iters--) {
|
||||
int end = std::min(num_nodes, i + 5);
|
||||
for (int j = i + 1; j < end; j++) {
|
||||
if (nodes[i] >= 0 && nodes[j] >= 0) {
|
||||
CHECK(g.InsertEdge(nodes[i], nodes[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
delete[] nodes;
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
|
37
tensorflow/compiler/jit/jit_compilation_pass_registration.cc
Normal file
@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                      MarkForCompilationPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      EncapsulateSubgraphsPass);

// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      BuildXlaLaunchOpsPass);

}  // namespace tensorflow
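For context, an additional graph pass hooks into the same registry by choosing the POST_REWRITE_FOR_EXEC grouping and a phase number after 30. The sketch below is illustrative only; MyPostXlaPass is a hypothetical name and not part of this commit, while REGISTER_OPTIMIZATION and GraphOptimizationPass are the interfaces used above.

// Illustrative sketch only: MyPostXlaPass is a made-up pass name.
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

class MyPostXlaPass : public GraphOptimizationPass {
 public:
  // A no-op body; a real pass would inspect or rewrite *options.graph here.
  Status Run(const GraphOptimizationPassOptions& options) override {
    return Status::OK();
  }
};

// Phase 40 places the pass after BuildXlaLaunchOpsPass (phase 30) within the
// same POST_REWRITE_FOR_EXEC group.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
                      MyPostXlaPass);

}  // namespace tensorflow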
67
tensorflow/compiler/jit/legacy_flags/BUILD
Normal file
@ -0,0 +1,67 @@
|
||||
# Legacy command line flags for the XLA bridge libraries.
|
||||
|
||||
# Please do not add more flags to this package.
|
||||
|
||||
# The XLA bridge libraries were written in an environment that allowed
|
||||
# command-line flags to be scattered freely throughout the libraries. This
|
||||
# model, while initially convenient, leads to a proliferation of unused command
|
||||
# line flags in tests and binaries, and serious problems in servers, where one
|
||||
# might wish parameters to be different in independent RPC calls to the same
|
||||
# routine.
|
||||
#
|
||||
# Please don't add more flags. If you're a library author, pass options and
|
||||
# parameters explicitly through the library's interface.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
cc_library(
|
||||
name = "encapsulate_subgraphs_pass_flags",
|
||||
srcs = ["encapsulate_subgraphs_pass_flags.cc"],
|
||||
hdrs = ["encapsulate_subgraphs_pass_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mark_for_compilation_pass_flags",
|
||||
srcs = ["mark_for_compilation_pass_flags.cc"],
|
||||
hdrs = ["mark_for_compilation_pass_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_check_op_flags",
|
||||
srcs = ["parallel_check_op_flags.cc"],
|
||||
hdrs = ["parallel_check_op_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -0,0 +1,63 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
|
||||
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static EncapsulateSubgraphsPassFlags* flags;
|
||||
static std::vector<Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new EncapsulateSubgraphsPassFlags;
|
||||
flags->tf_xla_parallel_checking = false;
|
||||
flag_list = new std::vector<Flag>({
|
||||
Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking,
|
||||
"Debug tool. Runs both JIT-compiled and interpreted graphs in "
|
||||
"parallel and verifies they produce the same outputs."),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with the XLA bridge's
|
||||
// encapsulate_subgraphs_pass module.
|
||||
void AppendEncapsulateSubgraphsPassFlags(std::vector<Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
|
||||
|
||||
// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with the XLA bridge's
|
||||
// encapsulate_subgraphs_pass module.
|
||||
void AppendEncapsulateSubgraphsPassFlags(
|
||||
std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with the XLA bridge's
|
||||
// encapsulate_subgraphs_pass module.
|
||||
typedef struct {
|
||||
bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and
|
||||
// interpreted graphs in parallel and verifies
|
||||
// they produce the same outputs.
|
||||
} EncapsulateSubgraphsPassFlags;
|
||||
|
||||
// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
|
||||
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static MarkForCompilationPassFlags* flags;
|
||||
static std::vector<Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new MarkForCompilationPassFlags;
|
||||
flags->tf_xla_auto_jit = 0;
|
||||
flags->tf_xla_min_cluster_size = 2;
|
||||
flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
|
||||
flags->tf_xla_clustering_debug = false;
|
||||
flag_list = new std::vector<Flag>({
|
||||
Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
|
||||
"Control compilation of operators into XLA computations on CPU and "
|
||||
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
|
||||
"things very likely to be improved; 2 = on for everything. "
|
||||
"Experimental."),
|
||||
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
|
||||
"Minimum number of operators in an XLA compilation. Ignored for "
|
||||
"operators placed on an XLA device or operators explicitly marked "
|
||||
"for compilation."),
|
||||
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
|
||||
"Maximum number of operators in an XLA compilation."),
|
||||
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
|
||||
"Dump graphs during XLA compilation."),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with the XLA bridge's
|
||||
// mark_for_compilation_pass module.
|
||||
void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the MarkForCompilationPassFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
@ -0,0 +1,59 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
|
||||
|
||||
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with the XLA bridge's
|
||||
// mark_for_compilation_pass module.
|
||||
void AppendMarkForCompilationPassFlags(
|
||||
std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with the XLA bridge's
|
||||
// mark_for_compilation_pass module.
|
||||
typedef struct {
|
||||
int32 tf_xla_auto_jit; // Control compilation of operators into XLA
|
||||
// computations on CPU and GPU devices. 0 = use
|
||||
// ConfigProto setting; -1 = off; 1 = on for things
|
||||
// very likely to be improved; 2 = on for everything.
|
||||
// Experimental.
|
||||
int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA
|
||||
// compilation. Ignored for operators placed
|
||||
// on an XLA device or operators explicitly
|
||||
// marked for compilation.
|
||||
int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA
|
||||
// compilation.
|
||||
bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
|
||||
} MarkForCompilationPassFlags;
|
||||
|
||||
// Return a pointer to the MarkForCompilationPassFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
|
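As the comments above note, tf_xla_auto_jit == 0 defers to the session's ConfigProto, while -1, 1, and 2 override it. The sketch below shows the ConfigProto route for switching auto-JIT on; it is illustrative only, and assumes the ON_1 level corresponds to the flag value 1 described above.

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Builds SessionOptions with the global JIT level set, which
// MarkForCompilationPass reads when tf_xla_auto_jit is left at 0.
tensorflow::SessionOptions JitEnabledSessionOptions() {
  tensorflow::SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
  return options;
}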
@ -0,0 +1,68 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for the XLA bridge's parallel_check_op module.
|
||||
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static ParallelCheckOpFlags* flags;
|
||||
static std::vector<Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new ParallelCheckOpFlags;
|
||||
flags->parallel_check_failfast = true;
|
||||
flags->parallel_check_atol = "1e-5";
|
||||
flags->parallel_check_rtol = "1e-5";
|
||||
flag_list = new std::vector<Flag>({
|
||||
Flag("parallel_check_failfast", &flags->parallel_check_failfast,
|
||||
"Fail immediately on first parallel-check comparison error."),
|
||||
Flag("parallel_check_atol", &flags->parallel_check_atol,
|
||||
"Absolute error tolerance for parallel-check comparison."),
|
||||
Flag("parallel_check_rtol", &flags->parallel_check_rtol,
|
||||
"Relative error tolerance for parallel-check comparison."),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with the XLA bridge's
|
||||
// parallel_check_op module.
|
||||
void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the ParallelCheckOpFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
ParallelCheckOpFlags* GetParallelCheckOpFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
@ -0,0 +1,52 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
|
||||
|
||||
// Legacy flags for the XLA bridge's parallel_check_op module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with the XLA bridge's
|
||||
// parallel_check_op module.
|
||||
void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with the XLA bridge's
|
||||
// parallel_check_op module.
|
||||
typedef struct {
|
||||
bool parallel_check_failfast; // Fail immediately on first parallel-check
|
||||
// comparison error.
|
||||
string parallel_check_atol; // Absolute error tolerance for parallel-check
|
||||
// comparison.
|
||||
string parallel_check_rtol; // Relative error tolerance for parallel-check
|
||||
// comparison.
|
||||
} ParallelCheckOpFlags;
|
||||
|
||||
// Return a pointer to the ParallelCheckOpFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
ParallelCheckOpFlags* GetParallelCheckOpFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
|
534
tensorflow/compiler/jit/mark_for_compilation_pass.cc
Normal file
@ -0,0 +1,534 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaClusterAttr = "_XlaCluster";
|
||||
|
||||
namespace {
|
||||
|
||||
bool HasXLAKernel(const NodeDef& node_def, DeviceType jit_device_type) {
|
||||
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
|
||||
// is really a kind of function call and will be handled by
|
||||
// IsCompilableCall().
|
||||
if (node_def.op() == "SymbolicGradient") return false;
|
||||
return FindKernelDef(jit_device_type, node_def, nullptr, nullptr).ok();
|
||||
}
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const int kMaxRecursionDepth = 5;
|
||||
|
||||
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
||||
int depth, FunctionLibraryRuntime* lib_runtime);
|
||||
|
||||
// Tests whether 'while_def' is a completely compilable loop.
|
||||
// Every operator in the condition and body functions must be compilable for a
|
||||
// while loop to be compilable.
|
||||
bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
|
||||
int depth, FunctionLibraryRuntime* lib_runtime) {
|
||||
VLOG(2) << "Loop marking: " << while_def.op();
|
||||
|
||||
const NameAttrList* name_attr;
|
||||
NodeDef call;
|
||||
Status status;
|
||||
status = GetNodeAttr(while_def, "cond", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Missing 'cond' attribute on While node.";
|
||||
return false;
|
||||
}
|
||||
const string cond_func = name_attr->name();
|
||||
call.set_name("while_cond");
|
||||
call.set_op(cond_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
|
||||
VLOG(2) << "Can't compile loop condition: " << cond_func;
|
||||
return false;
|
||||
}
|
||||
status = GetNodeAttr(while_def, "body", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Missing 'body' attribute on While node.";
|
||||
return false;
|
||||
}
|
||||
const string body_func = name_attr->name();
|
||||
call.set_name("while_body");
|
||||
call.set_op(body_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
|
||||
VLOG(2) << "Can't compile loop body: " << body_func;
|
||||
return false;
|
||||
}
|
||||
VLOG(2) << "Loop is compilable.";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Tests whether 'call_def' is a call to a completely compilable function.
|
||||
// Every operator in the function must be compilable for a function to be
|
||||
// compilable.
|
||||
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
||||
int depth, FunctionLibraryRuntime* lib_runtime) {
|
||||
VLOG(2) << "Function marking: " << call_def.op();
|
||||
|
||||
if (depth > kMaxRecursionDepth) {
|
||||
VLOG(2) << "Function depth limit exceeded";
|
||||
return false;
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status =
|
||||
lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
|
||||
return false;
|
||||
}
|
||||
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
|
||||
CHECK(fbody);
|
||||
|
||||
for (Node* node : fbody->graph->nodes()) {
|
||||
if (node->IsSource() || node->IsSink()) continue;
|
||||
if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
|
||||
if (node->def().op() == "While") {
|
||||
// Handle functional While loop (not in open source build).
|
||||
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
|
||||
lib_runtime);
|
||||
}
|
||||
if (!HasXLAKernel(node->def(), jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Function marking failed: unsupported op " << node->name()
|
||||
<< ": " << node->def().ShortDebugString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Function is compilable: " << call_def.op();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the DeviceType corresponding to 'device'.
|
||||
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
|
||||
return errors::Internal("Malformed assigned device '", device, "'");
|
||||
}
|
||||
*device_type = DeviceType(parsed.type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FindCompilationCandidates(
|
||||
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
|
||||
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
|
||||
std::unordered_set<Node*>* candidates) {
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
|
||||
nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
|
||||
|
||||
for (Node* node : graph.nodes()) {
|
||||
if (node->IsSource() || node->IsSink()) continue;
|
||||
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
|
||||
|
||||
if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
|
||||
|
||||
const string* jit_device_name;
|
||||
CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name,
|
||||
/*requires_jit=*/nullptr));
|
||||
DeviceType jit_device_type(*jit_device_name);
|
||||
if (!HasXLAKernel(node->def(), jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
|
||||
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
|
||||
<< ": " << node->def().op();
|
||||
continue;
|
||||
}
|
||||
if (node->def().op() == "While" &&
|
||||
!IsCompilableWhile(node->def(), jit_device_type, 0,
|
||||
lib_runtime.get())) {
|
||||
continue;
|
||||
}
|
||||
candidates->insert(node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Union-Find data structure used to compute clusters. We use our own
|
||||
// implementation because we want one key feature: when merging clusters, we
|
||||
// need to know which value becomes the representative of the merged clusters.
|
||||
// We use the representatives to name nodes in a cycle detection graph, and we
|
||||
// need to control which node is named.
|
||||
// TODO(phawkins): consider merging this code with union-find implementations
|
||||
// in Tensorflow, e.g., in SimplePlacer.
|
||||
class Cluster {
|
||||
public:
|
||||
Cluster();
|
||||
|
||||
int Size() { return FindRoot()->size_; }
|
||||
|
||||
// Merges this cluster with 'other'. This cluster's representative becomes
|
||||
// the representative of the merged cluster; the representative of 'other'
|
||||
// is ignored.
|
||||
void Merge(Cluster* other);
|
||||
|
||||
// Each cluster has an associated integer 'representative', initialized to -1
|
||||
// by default.
|
||||
int GetRepresentative() { return FindRoot()->representative_; }
|
||||
void SetRepresentative(int representative) {
|
||||
FindRoot()->representative_ = representative;
|
||||
}
|
||||
|
||||
private:
|
||||
// Finds the root element of the cluster. Performs path compression.
|
||||
Cluster* FindRoot();
|
||||
|
||||
int representative_;
|
||||
int rank_;
|
||||
int size_; // Size of the cluster.
|
||||
Cluster* parent_;
|
||||
};
|
||||
|
||||
Cluster::Cluster()
|
||||
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
|
||||
|
||||
void Cluster::Merge(Cluster* other) {
|
||||
Cluster* a = FindRoot();
|
||||
Cluster* b = other->FindRoot();
|
||||
if (a == b) return;
|
||||
if (a->rank_ > b->rank_) {
|
||||
b->parent_ = a;
|
||||
a->size_ += b->size_;
|
||||
return;
|
||||
}
|
||||
|
||||
a->parent_ = b;
|
||||
if (a->rank_ == b->rank_) {
|
||||
b->rank_++;
|
||||
}
|
||||
b->representative_ = a->representative_;
|
||||
b->size_ += a->size_;
|
||||
}
|
||||
|
||||
Cluster* Cluster::FindRoot() {
|
||||
if (!parent_) return this;
|
||||
// Path compression: update intermediate nodes to point to the root of the
|
||||
// equivalence class.
|
||||
parent_ = parent_->FindRoot();
|
||||
return parent_;
|
||||
}
|
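// A small illustration (not part of the pass) of the property called out in
// the comment above: the left-hand side of Merge() supplies the
// representative of the merged cluster, regardless of which root wins the
// union by rank. For example:
//
//   Cluster a, b;
//   a.SetRepresentative(1);
//   b.SetRepresentative(2);
//   a.Merge(&b);
//   // Now a.GetRepresentative() == 1, b.GetRepresentative() == 1, and
//   // a.Size() == 2: the merged cluster is named by 'a'.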
||||
|
||||
} // anonymous namespace
|
||||
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
|
||||
Device* device = flr->device();
|
||||
const string* jit_device_name;
|
||||
CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name,
|
||||
/*requires_jit=*/nullptr));
|
||||
DeviceType jit_device_type(*jit_device_name);
|
||||
return IsCompilableCall(ndef, jit_device_type, 0, flr);
|
||||
}
|
||||
|
||||
Status MarkForCompilationPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
// TODO(phawkins): precompute the "GetJitDevice" properties of each device
// ahead of time.
|
||||
OptimizerOptions::GlobalJitLevel global_jit_level =
|
||||
options.session_options->config.graph_options()
|
||||
.optimizer_options()
|
||||
.global_jit_level();
|
||||
if (global_jit_level == OptimizerOptions::DEFAULT) {
|
||||
// To set compilation to be on by default, change the following line.
|
||||
global_jit_level = OptimizerOptions::OFF;
|
||||
}
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
if (flags->tf_xla_auto_jit == -1 ||
|
||||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
|
||||
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
|
||||
// the setting in ConfigProto.
|
||||
global_jit_level =
|
||||
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
|
||||
}
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
auto is_compilable = [global_jit_level, fld](const Node* node,
|
||||
const DeviceType& device_type) {
|
||||
const string* jit_device;
|
||||
bool requires_jit;
|
||||
if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device,
|
||||
&requires_jit)) {
|
||||
return false;
|
||||
}
|
||||
// If this device requires a JIT, we must say yes.
|
||||
if (requires_jit) return true;
|
||||
|
||||
// If there is a _XlaCompile annotation, use its value.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
// Otherwise use the value of global_jit_level.
|
||||
return global_jit_level > 0;
|
||||
};
|
||||
return RunImpl(options, is_compilable);
|
||||
}
|
||||
|
||||
// Is 'node' an operator that consumes only the shape of its input, not the
|
||||
// data itself?
|
||||
static bool IsShapeConsumerOp(const Node& node) {
|
||||
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
|
||||
node.type_string() == "Size";
|
||||
}
|
||||
|
||||
// Sequence number generator to ensure clusters have unique names.
|
||||
static std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
Status MarkForCompilationPass::RunImpl(
|
||||
const GraphOptimizationPassOptions& options,
|
||||
const std::function<bool(const Node*, const DeviceType&)>&
|
||||
is_compilable_fn) {
|
||||
VLOG(1) << "MarkForCompilationPass::Run";
|
||||
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterJitKernels();
|
||||
|
||||
Graph* graph = options.graph->get();
|
||||
|
||||
std::unordered_set<Node*> compilation_candidates;
|
||||
TF_RETURN_IF_ERROR(FindCompilationCandidates(
|
||||
*graph, options.flib_def,
|
||||
(options.session_options != nullptr) ? options.session_options->env
|
||||
: Env::Default(),
|
||||
is_compilable_fn, &compilation_candidates));
|
||||
|
||||
GraphCycles cycles;
|
||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||
// We rely on the node IDs in the cycle detection graph being consecutive
|
||||
// integers starting from 0.
|
||||
CHECK_EQ(i, cycles.NewNode());
|
||||
}
|
||||
|
||||
// Compute the loop structure of the graph.
|
||||
std::vector<ControlFlowInfo> control_flow_info;
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
|
||||
|
||||
// The clustering code must avoid adding cycles to the graph to prevent
|
||||
// deadlock. However, the graph may contain loops, which would trigger the
|
||||
// cycle detection code. To handle loops, we alter the structure of the cycle
|
||||
// detection graph, disconnecting each loop from the enclosing graph.
|
||||
// Specifically, we:
|
||||
// * add a new "frame" node for each loop.
|
||||
// * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
|
||||
// to/from the corresponding frame node. In essence, we collapse the loop
|
||||
// into a single node for the purpose of cycle detection in the enclosing
|
||||
// graph.
|
||||
// * the body of the loop should now be disconnected from the rest of the
|
||||
// graph; we make it acyclic by breaking loop backedges (edges outgoing from
// "NextIteration" nodes).
|
||||
|
||||
// Map from frame name strings to node IDs in the cycle detection graph.
|
||||
std::unordered_map<string, int> frame_nodes;
|
||||
|
||||
// Get the cycle graph node ID for frame 'frame_name', or add one if none
|
||||
// exists.
|
||||
auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
|
||||
int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
|
||||
if (frame_id < 0) {
|
||||
// The emplace succeeded; we have not allocated a frame node yet.
|
||||
frame_id = cycles.NewNode();
|
||||
}
|
||||
return frame_id;
|
||||
};
|
||||
|
||||
for (Edge const* edge : graph->edges()) {
|
||||
if (edge->dst()->IsEnter()) {
|
||||
// Lift edges to an "Enter" node to the corresponding frame node.
|
||||
const string& frame_name =
|
||||
control_flow_info[edge->dst()->id()].frame_name;
|
||||
if (!cycles.InsertEdge(edge->src()->id(),
|
||||
GetOrAddFrameNodeId(frame_name))) {
|
||||
return errors::Internal("Cycle detected when adding enter->frame edge");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (edge->src()->IsExit()) {
|
||||
// Lift edges from an "Exit" node to the corresponding frame node.
|
||||
const string& frame_name =
|
||||
control_flow_info[edge->src()->id()].frame_name;
|
||||
if (!cycles.InsertEdge(GetOrAddFrameNodeId(frame_name),
|
||||
edge->dst()->id())) {
|
||||
return errors::Internal("Cycle detected when adding frame->exit edge");
|
||||
}
|
||||
// Drop the original edge.
|
||||
continue;
|
||||
}
|
||||
if (edge->src()->IsNextIteration()) {
|
||||
// Break loop back-edges.
|
||||
continue;
|
||||
}
|
||||
if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
|
||||
// This should never happen. All cycles in the graph should contain
|
||||
// a control flow operator.
|
||||
return errors::Internal(
|
||||
"Found cycle in graph without control flow operator during XLA "
|
||||
"compilation.");
|
||||
}
|
||||
}
|
||||
|
||||
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
|
||||
std::vector<Cluster> clusters(graph->num_node_ids());
|
||||
std::deque<Cluster*> worklist;
|
||||
for (Node* node : compilation_candidates) {
|
||||
clusters[node->id()].SetRepresentative(node->id());
|
||||
worklist.push_back(&clusters[node->id()]);
|
||||
}
|
||||
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
|
||||
// Repeatedly contract edges between clusters that are on the same device,
|
||||
// provided the contraction would not create a cycle.
|
||||
while (!worklist.empty()) {
|
||||
int from = worklist.front()->GetRepresentative();
|
||||
worklist.pop_front();
|
||||
|
||||
Node* node_from = graph->FindNodeId(from);
|
||||
if (node_from->IsControlFlow()) {
|
||||
// Control flow nodes aren't compilation candidates and should never
|
||||
// appear.
|
||||
return errors::Internal("Found control flow node in clustering worklist");
|
||||
}
|
||||
for (int to : cycles.Successors(from)) {
|
||||
if (to >= graph->num_node_ids()) {
|
||||
// Node is a "frame" node that is present only in the cycle detection
|
||||
// graph. No clustering is possible.
|
||||
continue;
|
||||
}
|
||||
Node* node_to = graph->FindNodeId(to);
|
||||
if (compilation_candidates.find(node_to) == compilation_candidates.cend())
|
||||
continue;
|
||||
if (node_from->assigned_device_name() != node_to->assigned_device_name())
|
||||
continue;
|
||||
|
||||
// Ops that consume shapes cannot be the root of a cluster. This is an
|
||||
// optimization.
|
||||
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't exceed the maximum cluster size.
|
||||
if (clusters[from].Size() + clusters[to].Size() >
|
||||
flags->tf_xla_max_cluster_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If contracting the edge would create a cycle, bail out.
|
||||
// However, just because we can't merge the clusters now does not mean
|
||||
// we won't be able to merge them in the future.
|
||||
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
|
||||
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
|
||||
if (!cycles.ContractEdge(from, to)) continue;
|
||||
|
||||
// Merge the clusters. ContractEdge uses 'from' as the number of the
|
||||
// merged node, so make sure 'from' is the chosen representative.
|
||||
clusters[from].Merge(&clusters[to]);
|
||||
|
||||
worklist.push_back(&clusters[from]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Count the number of elements in each cluster.
|
||||
std::vector<int> cluster_sizes(graph->num_node_ids());
|
||||
for (const Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].GetRepresentative();
|
||||
cluster_sizes[cluster]++;
|
||||
}
|
||||
|
||||
// Names for each cluster.
|
||||
std::unordered_map<int, string> cluster_names;
|
||||
|
||||
// Mark clusters for compilation that:
|
||||
// * are placed on a device that requires compilation (an XlaDevice),
|
||||
// * are explicitly marked for compilation (_XlaCompile=true), or
|
||||
// * have at least flags->tf_xla_min_cluster_size elements (applicable only
|
||||
// if compilation is enabled, otherwise there will be no such candidates).
|
||||
const int min_cluster_size = flags->tf_xla_min_cluster_size;
|
||||
for (Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].GetRepresentative();
|
||||
|
||||
// Compile if the user marked this node _XlaCompile=true
|
||||
bool compile_attr = false;
|
||||
bool marked_for_compilation = false;
|
||||
if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
|
||||
marked_for_compilation = compile_attr;
|
||||
} else if (options.flib_def
|
||||
->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
|
||||
.ok()) {
|
||||
marked_for_compilation = compile_attr;
|
||||
}
|
||||
|
||||
// Compile if this operator is placed on a device that requires
|
||||
// compilation.
|
||||
bool requires_jit = false;
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
|
||||
XlaOpRegistry::GetJitDevice(device_type.type(),
|
||||
/*jit_device_name=*/nullptr, &requires_jit);
|
||||
|
||||
// Or compile if this is a cluster of >= min_cluster_size compilable
|
||||
// operators.
|
||||
if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
|
||||
requires_jit) {
|
||||
string& name = cluster_names[cluster];
|
||||
if (name.empty()) {
|
||||
name = strings::StrCat("cluster_", cluster_sequence_num++);
|
||||
}
|
||||
n->AddAttr(kXlaClusterAttr, name);
|
||||
}
|
||||
}
|
||||
|
||||
if (flags->tf_xla_clustering_debug) {
|
||||
dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
|
||||
options.flib_def);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
55
tensorflow/compiler/jit/mark_for_compilation_pass.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// An optimization pass that marks nodes that are to be compiled with
|
||||
// attribute kXlaClusterAttr. Nodes with the same cluster ID will be compiled
|
||||
// together.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The attribute that marks nodes to be grouped into functions by the
|
||||
// encapsulate subgraphs pass.
|
||||
extern const char* const kXlaClusterAttr;
|
||||
|
||||
// Pass that marks a subset of operators in the graph with attribute
|
||||
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
|
||||
class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
public:
|
||||
MarkForCompilationPass() = default;
|
||||
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
// Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
|
||||
// unconditionally, call RunImpl() directly.
|
||||
// is_compilable_fn, if set, is a predicate that must be true for a node to
|
||||
// be compiled.
|
||||
Status RunImpl(const GraphOptimizationPassOptions& options,
|
||||
const std::function<bool(const Node*, const DeviceType&)>&
|
||||
is_compilable_fn = {});
|
||||
};
|
||||
|
||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable.
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
357
tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Normal file
@ -0,0 +1,357 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
REGISTER_OP("UncompilableNullary").Output("o: float");
|
||||
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
|
||||
|
||||
void MarkForCompilation(std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
// Assign all nodes to the CPU device.
|
||||
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
|
||||
for (Node* n : (*graph)->nodes()) {
|
||||
n->set_assigned_device_name(kCpuDevice);
|
||||
}
|
||||
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.graph = graph;
|
||||
opt_options.flib_def = flib_def;
|
||||
MarkForCompilationPass pass;
|
||||
CHECK(pass.RunImpl(opt_options).ok());
|
||||
}
|
||||
|
||||
void MarkForCompilation(std::unique_ptr<Graph>* graph) {
|
||||
FunctionDefLibrary flib;
|
||||
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
|
||||
MarkForCompilation(graph, &flib_def);
|
||||
}
|
||||
|
||||
std::unordered_map<string, string> GetClusters(const Graph& graph) {
|
||||
std::unordered_map<string, string> ids;
|
||||
for (Node* node : graph.nodes()) {
|
||||
string cluster;
|
||||
if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
|
||||
CHECK(!cluster.empty());
|
||||
ids[node->name()] = cluster;
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, Chains) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
|
||||
Node* d =
|
||||
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
|
||||
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
|
||||
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(4, clusters.size());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_EQ(clusters["E"], clusters["F"]);
|
||||
EXPECT_NE(clusters["B"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, UncompilableCycles) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b =
|
||||
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, CompilableCycles) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(3, clusters.size());
|
||||
EXPECT_EQ(clusters["A"], clusters["B"]);
|
||||
EXPECT_EQ(clusters["A"], clusters["C"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, UnsupportedTypes) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
"Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_COMPLEX64)
|
||||
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
|
||||
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ConcatWithConstArg) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
Tensor t(DT_INT32, TensorShape());
|
||||
t.scalar<int32>()() = 0;
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* dim = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("Dim")
|
||||
.WithAttr("dtype", DT_INT32)
|
||||
.WithAttr("value", t));
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", t));
|
||||
|
||||
NodeBuilder concat_builder("Concat", "Concat",
|
||||
builder.opts().op_registry());
|
||||
concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
|
||||
builder.opts().FinalizeBuilder(&concat_builder);
|
||||
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, FunctionCalls) {
|
||||
FunctionDefLibrary flib;
|
||||
*flib.add_function() = FunctionDefHelper::Define(
|
||||
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
|
||||
{{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
|
||||
*flib.add_function() =
|
||||
FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
|
||||
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(&flib_def));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
|
||||
ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph, &flib_def);
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_FALSE(clusters["B"].empty());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
|
||||
// cluster. This is partially to work around b/26800664, and partially because
|
||||
// we should probably prefer to compile metadata operators with their producers
|
||||
// wherever possible, rather than their consumers.
|
||||
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
// While all of the following ops are notionally compilable, none is
|
||||
// permitted
|
||||
// to start a cluster. So nothing should be compiled.
|
||||
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
|
||||
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
|
||||
ops::UnaryOp("Shape", d, builder.opts().WithName("C"));
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
|
||||
}
|
||||
|
||||
static Status GradForUnaryCwise(FunctionDef* g,
|
||||
std::vector<FunctionDefHelper::Node> nodes) {
|
||||
for (auto& n : nodes) {
|
||||
if (n.attr.empty()) {
|
||||
n.attr = {{"T", DT_FLOAT}};
|
||||
}
|
||||
}
|
||||
*g = FunctionDefHelper::Define(
|
||||
// Arg defs
|
||||
{"x: float", "dy: float"},
|
||||
// Ret val defs
|
||||
{"dx: float"},
|
||||
// Attr defs
|
||||
{},
|
||||
// Nodes
|
||||
nodes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A gradient containing only supported operators
|
||||
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForUnaryCwise(g, {
|
||||
{{"y"}, "Tanh", {"x"}},
|
||||
{{"y2"}, "Square", {"y"}, {}, {"dy"}},
|
||||
FunctionDefHelper::Const("one", 1.0f),
|
||||
{{"a"}, "Sub", {"one", "y2"}},
|
||||
{{"dx"}, "Mul", {"dy", "a"}},
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Supported", SupportedGrad);
|
||||
|
||||
// A gradient containing an unsupported operator.
|
||||
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForUnaryCwise(g, {
|
||||
{{"y"}, "Tanh", {"x"}},
|
||||
{{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
|
||||
FunctionDefHelper::Const("one", 1.0f),
|
||||
{{"a"}, "Sub", {"one", "y2"}},
|
||||
{{"dx"}, "Mul", {"dy", "a"}},
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
|
||||
|
||||
TEST(XlaCompilationTest, SymbolicGradients) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
|
||||
// Builds a Symbolic gradient for Supported
|
||||
NodeBuilder b_builder("B", "SymbolicGradient",
|
||||
builder.opts().op_registry());
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("Supported");
|
||||
b_builder.Attr("f", b_name_attr);
|
||||
b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
|
||||
b_builder.Attr("Tout", {DT_FLOAT});
|
||||
b_builder.Input({a, a});
|
||||
Node* b = builder.opts().FinalizeBuilder(&b_builder);
|
||||
|
||||
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
|
||||
|
||||
// Builds a Symbolic gradient for Unsupported
|
||||
NodeBuilder d_builder("D", "SymbolicGradient",
|
||||
builder.opts().op_registry());
|
||||
NameAttrList d_name_attr;
|
||||
d_name_attr.set_name("Unsupported");
|
||||
d_builder.Attr("f", d_name_attr);
|
||||
d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
|
||||
d_builder.Attr("Tout", {DT_FLOAT});
|
||||
d_builder.Input({c, c});
|
||||
builder.opts().FinalizeBuilder(&d_builder);
|
||||
|
||||
builder.ToGraph(graph.get());
|
||||
}
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_FALSE(clusters["B"].empty());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, Loops) {
|
||||
// Regression test for b/32350199, where the autoclustering code introduced a
|
||||
// deadlock in a graph containing a while loop.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
|
||||
auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
|
||||
auto c = ops::Add(root.WithOpName("C"), a, b);
|
||||
auto enter = ops::Enter(root, c, "aframe");
|
||||
auto next_iter = ops::NextIteration(root, enter);
|
||||
auto exit = ops::Exit(root, next_iter);
|
||||
auto d = ops::Add(root.WithOpName("D"), c, exit);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
root.ToGraph(graph.get());
|
||||
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
// Nothing should be compiled. In particular, 'd' and 'c' must not be
|
||||
// compiled.
|
||||
EXPECT_EQ(0, clusters.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
154
tensorflow/compiler/jit/parallel_check_op.cc
Normal file
@@ -0,0 +1,154 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

REGISTER_OP("ParallelCheck")
    .Attr("T: list(type) >= 0")
    .Input("expected: T")
    .Input("actual: T")
    .Output("result: T")
    .Doc(R"doc(
Op that compares two sets of inputs for near-identity, and propagates the first.
Inequality is logged to ERROR log.
)doc");

// Inputs 2*N tensors, outputs the first N inputs.
// Logs errors if input tensor i and i + N are not (near) identical
// in any position.
class ParallelCheckOp : public OpKernel {
 public:
  explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  template <typename T>
  int CompareTensors(DataType dtype, const char* v0, const char* v1,
                     int64 num_elts, int input_idx) {
    int failed = 0;
    const T* p0 = reinterpret_cast<const T*>(v0);
    const T* p1 = reinterpret_cast<const T*>(v1);
    double rtol;
    legacy_flags::ParallelCheckOpFlags* flags =
        legacy_flags::GetParallelCheckOpFlags();
    if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
                                          &rtol)) {
      LOG(ERROR) << "can't convert parallel_check_rtol "
                 << flags->parallel_check_rtol << " to double";
    }
    double atol;
    if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
                                          &atol)) {
      LOG(ERROR) << "can't convert parallel_check_atol "
                 << flags->parallel_check_atol << " to double";
    }
    for (int i = 0; i < num_elts; ++i) {
      bool ok = (p0[i] == p1[i]);
      VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
      if (!ok) {
        if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
          float tolerance =
              std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
          T diff = p0[i] - p1[i];
          if (diff < 0) diff = 0 - diff;
          ok = (diff <= tolerance);
        }
        if (ok) continue;
        LOG(ERROR) << "Op " << def().name() << " fails equality at output "
                   << input_idx << " type " << DataTypeString(dtype)
                   << " element " << i << ": std_val=" << p0[i]
                   << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
        if (++failed > 10) break;
      }
    }
    return failed;
  }

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "Compute " << def().name();
    const int num_pairs = ctx->num_inputs() / 2;
    for (int i = 0; i < num_pairs; ++i) {
      CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
      Tensor t0 = ctx->input(i);
      Tensor t1 = ctx->input(i + num_pairs);
      int64 num_elts = t0.NumElements();
      CHECK_EQ(num_elts, t1.NumElements());

      // Compare inputs elementwise for near-exact equality.
      const char* v0 = t0.tensor_data().data();
      const char* v1 = t1.tensor_data().data();
      int failed = 0;
      switch (ctx->input_dtype(i)) {
        case DT_INT32:
          failed =
              CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
          break;
        case DT_INT64:
          failed =
              CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
          break;
        case DT_FLOAT:
          failed =
              CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
          break;
        case DT_DOUBLE:
          failed =
              CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
          break;
        case DT_BOOL:
          failed =
              CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
          break;
        default:
          LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
      }
      if (failed > 0) {
        LOG(ERROR) << "check failed for " << def().name() << " output " << i
                   << " num_elts: " << num_elts;
        legacy_flags::ParallelCheckOpFlags* flags =
            legacy_flags::GetParallelCheckOpFlags();
        if (flags->parallel_check_failfast) {
          LOG(QFATAL) << "failfast on first parallel-check failure";
        }
      } else {
        VLOG(1) << "check passed for " << def().name() << " output " << i
                << " num_elts: " << num_elts;
      }

      // Propagate the std value.
      if (IsRefType(ctx->input_dtype(i))) {
        ctx->forward_ref_input_to_ref_output(i, i);
      } else {
        ctx->set_output(i, ctx->input(i));
      }
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
};

REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
                        ParallelCheckOp);

}  // namespace
}  // namespace tensorflow
199
tensorflow/compiler/jit/xla_compilation_cache.cc
Normal file
@@ -0,0 +1,199 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options)
    : compiler_(options) {}

XlaCompilationCache::~XlaCompilationCache() = default;

string XlaCompilationCache::DebugString() {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
  string result = sig.name;
  for (const auto& a : sig.arg_types) {
    strings::StrAppend(&result, ",", DataTypeString(a.first),
                       a.second.DebugString());
  }

  for (const auto& v : sig.arg_values) {
    strings::StrAppend(&result, "; ", v.first, ":", v.second.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_types != other.arg_types) return false;

  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0; i < arg_values.size(); ++i) {
    if (arg_values[i].first != other.arg_values[i].first ||
        arg_values[i].second.tensor_data() !=
            other.arg_values[i].second.tensor_data()) {
      return false;
    }
  }
  return true;
}

uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_types) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.dims()));
    for (int dim : arg.second.dim_sizes()) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, Hash64(arg.second.tensor_data().data(),
                                arg.second.tensor_data().size()));
  }
  return h;
}

namespace {

// Builds an XlaCompiler::Argument vector from the arguments to the _XlaLaunch
// op. The first `num_constant_args` arguments must be host-memory Tensors.
std::vector<XlaCompiler::Argument> BuildArguments(int num_constant_args,
                                                  OpKernelContext* ctx) {
  std::vector<XlaCompiler::Argument> args(ctx->num_inputs());
  int parameter_num = 0;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    args[i].type = ctx->input(i).dtype();
    args[i].shape = ctx->input(i).shape();
    if (i < num_constant_args || ctx->input(i).NumElements() == 0) {
      args[i].parameter = -1;
      args[i].constant_value = ctx->input(i);
    } else {
      args[i].parameter = parameter_num;
      ++parameter_num;
    }
  }
  return args;
}

}  // namespace

Status XlaCompilationCache::Compile(
    const NameAttrList& function, int num_constant_args, OpKernelContext* ctx,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  VLOG(1) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    std::vector<string> argshapes;
    VLOG(2) << "num_inputs = " << ctx->num_inputs()
            << " num_constant_args= " << num_constant_args;
    for (int i = 0; i < ctx->num_inputs(); i++) {
      TensorShape shape = ctx->input(i).shape();
      VLOG(2) << i << ": dtype=" << ctx->input_dtype(i)
              << " present=" << ctx->has_input(i)
              << " shape=" << shape.DebugString();
      argshapes.push_back(shape.DebugString());
    }
    VLOG(2) << "num_outputs = " << ctx->num_outputs();
    for (int i = 0; i < ctx->num_outputs(); i++) {
      VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i);
    }
  }
  Signature signature;
  signature.name = Canonicalize(function.name(), function.attr());
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    signature.arg_types.emplace_back(ctx->input_dtype(i),
                                     ctx->input(i).shape());
    if (i < num_constant_args) {
      signature.arg_values.emplace_back(i, ctx->input(i));
    }
  }

  VLOG(2) << "Signature: " << SignatureDebugString(signature);
  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we
  // implement cache eviction.
  mutex_lock entry_lock(entry->mu);
  if (!entry->compiled) {
    // Do the actual JIT compilation without holding the lock (it can take
    // a long time).
    std::vector<XlaCompiler::Argument> args =
        BuildArguments(num_constant_args, ctx);

    std::unique_ptr<FunctionLibraryRuntime> flr(NewFunctionLibraryRuntime(
        compiler_.device_mgr(), ctx->env(), compiler_.device(),
        TF_GRAPH_DEF_VERSION,
        ctx->function_library()->GetFunctionLibraryDefinition(),
        OptimizerOptions(), nullptr /* custom_kernel_creator */));

    entry->compiled = true;
    entry->compilation_status = compiler_.CompileFunction(
        flr.get(), function, args, &entry->compilation_result);
  }
  *compilation_result = &entry->compilation_result;
  if (entry->compilation_status.ok() && executable) {
    if (entry->executable == nullptr &&
        !entry->compilation_result.computation.IsNull()) {
      entry->compilation_status = compiler_.BuildExecutable(
          entry->compilation_result, &entry->executable);
    }
    *executable = entry->executable.get();
  }

  Status status = entry->compilation_status;
  return status;
}

}  // namespace tensorflow
112
tensorflow/compiler/jit/xla_compilation_cache.h
Normal file
@@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_

#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

// The XlaCompilationCache class caches the results of the XlaCompiler class,
// which converts a Tensorflow graph into a compiled XLA compilation.
//
// Since XLA computations must have static shapes, the cache generates a new
// XLA computation for each new set of input shapes.
//
// Currently no cache eviction policy is implemented and the cache grows
// without bound.
class XlaCompilationCache : public ResourceBase {
 public:
  explicit XlaCompilationCache(const XlaCompiler::Options& options);
  ~XlaCompilationCache() override;

  // Compiles a function into an XlaCompiler::CompilationResult that can be
  // used to execute an XLA Computation. `compilation_result` must be non-null.
  // If `executable` is non-null, also builds an xla::LocalExecutable and sets
  // `executable` to point to it. The resulting executable pointer may be null
  // if the computation has no non-constant outputs.
  // Compilation results are cached.
  Status Compile(const NameAttrList& function, int num_constant_args,
                 OpKernelContext* ctx,
                 const XlaCompiler::CompilationResult** compilation_result,
                 xla::LocalExecutable** executable);

  xla::Client* client() const { return compiler_.client(); }

  string DebugString() override;

 private:
  XlaCompiler compiler_;
  std::unique_ptr<FunctionLibraryRuntime> function_library_runtime_;

  // Describes the types, shapes and any compile-time constant arguments
  // to a kernel.
  struct Signature {
    string name;

    std::vector<std::pair<DataType, TensorShape>> arg_types;

    // List of (argument #, value) pairs for arguments whose values are
    // part of the JIT signature, and that are therefore constants in any given
    // JIT compilation. Tensors must be in host memory.
    std::vector<std::pair<int, Tensor>> arg_values;

    bool operator==(const Signature& other) const;

    struct Hash {
      uint64 operator()(const Signature& signature) const;
    };
  };
  static string SignatureDebugString(const Signature& sig);

  // The value associated with a cache entry.
  struct Entry {
    mutex mu;

    // Have we tried compiling this entry?
    bool compiled = false;

    // Did compilation succeed?
    Status compilation_status GUARDED_BY(mu);

    // Output of the XlaCompiler.
    XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);

    // The XLA executable compiled from <computation>. May be null if no
    // executable has been built.
    std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
  };

  mutex mu_;
  std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
      GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
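For orientation, a hedged usage sketch of the Compile() entry point declared above (not part of this commit; the helper name, variable names, and the surrounding kernel are hypothetical):

// Sketch only: a launch-style kernel body that asks the cache for a compiled
// executable, using nothing beyond the signature declared in this header.
void CompileAndReport(XlaCompilationCache* cache, const NameAttrList& func,
                      int num_constant_args, OpKernelContext* ctx) {
  const XlaCompiler::CompilationResult* result = nullptr;
  xla::LocalExecutable* executable = nullptr;
  Status s =
      cache->Compile(func, num_constant_args, ctx, &result, &executable);
  if (!s.ok()) {
    ctx->SetStatus(s);
    return;
  }
  // `executable` may be null if the computation has no non-constant outputs;
  // otherwise it could be run via the client returned by cache->client().
}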
60
tensorflow/compiler/jit/xla_cpu_device.cc
Normal file
@@ -0,0 +1,60 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

const char* const DEVICE_XLA_CPU = "XLA_CPU";

class XlaCpuDeviceFactory : public DeviceFactory {
 public:
  Status CreateDevices(const SessionOptions& options,
                       const string& name_prefix,
                       std::vector<Device*>* devices) override;
};

Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                          const string& name_prefix,
                                          std::vector<Device*>* devices) {
  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
  (void)registrations;

  std::unique_ptr<XlaDevice> device;
  TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
                                       DEVICE_CPU_XLA_JIT, options,
                                       name_prefix, &device));
  devices->push_back(device.release());
  return Status::OK();
}

REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaDeviceLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);

}  // namespace tensorflow
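Because XlaDevice::Create is backend-agnostic, a factory for another XLA backend would mirror the CPU factory above. A sketch only; the "CUDA" platform name, the "XLA_GPU" device name, and the "XLA_GPU_JIT" JIT device name are assumptions for illustration, not names defined in this file:

// Sketch: a hypothetical GPU-backed XLA device factory shaped like
// XlaCpuDeviceFactory; all backend-specific strings here are assumptions.
class HypotheticalXlaGpuDeviceFactory : public DeviceFactory {
 public:
  Status CreateDevices(const SessionOptions& options,
                       const string& name_prefix,
                       std::vector<Device*>* devices) override {
    std::unique_ptr<XlaDevice> device;
    TF_RETURN_IF_ERROR(XlaDevice::Create("CUDA", "XLA_GPU",
                                         /*device_ordinal=*/0, "XLA_GPU_JIT",
                                         options, name_prefix, &device));
    devices->push_back(device.release());
    return Status::OK();
  }
};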
219
tensorflow/compiler/jit/xla_device.cc
Normal file
@@ -0,0 +1,219 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>
#include <unordered_set>

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

/* static */ Status XlaDevice::Create(
    const string& platform_name, const string& device_name, int device_ordinal,
    const string& jit_device_name, const SessionOptions& options,
    const string& name_prefix, std::unique_ptr<XlaDevice>* device) {
  VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
          << device_ordinal;

  // These are no-ops if they have already been done previously for
  // this device_name/jit_device_name pair.
  XlaOpRegistry::RegisterJitKernels();
  XlaOpRegistry::RegisterJitDevice(device_name, jit_device_name,
                                   /*requires_jit=*/true);

  auto platform = perftools::gputools::MultiPlatformManager::PlatformWithName(
      platform_name);
  if (!platform.ok()) {
    return StreamExecutorUtil::ConvertStatus(platform.status());
  }

  const DeviceAttributes attrs = Device::BuildDeviceAttributes(
      strings::StrCat(name_prefix, "/device:", device_name, ":",
                      device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      strings::StrCat("device: ", device_name, " device"));

  static Allocator* allocator = new XlaDeviceAllocator;
  device->reset(new XlaDevice(options, attrs, device_ordinal,
                              DeviceType(jit_device_name),
                              platform.ValueOrDie(), allocator));
  return Status::OK();
}

XlaDevice::Metadata::Metadata(int device_ordinal,
                              perftools::gputools::Platform* platform,
                              const DeviceType& device_type)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

perftools::gputools::Platform* XlaDevice::Metadata::platform() const {
  return platform_;
}

XlaDevice::Metadata::~Metadata() {}

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; }

XlaDevice::XlaDevice(const SessionOptions& options,
                     const DeviceAttributes& attrs, int device_ordinal,
                     const DeviceType& jit_device_name,
                     perftools::gputools::Platform* platform,
                     Allocator* xla_allocator)
    : LocalDevice(options, attrs, xla_allocator),
      device_ordinal_(device_ordinal),
      jit_device_name_(jit_device_name),
      xla_allocator_(xla_allocator),
      platform_(platform) {
  // Store the platform in the resource manager so Ops can retrieve it
  // e.g., to lazily create a XlaCompilationCache object.
  TF_CHECK_OK(resource_manager()->Create<Metadata>(
      resource_manager()->default_container(), "xla_metadata",
      new Metadata(device_ordinal_, platform_, jit_device_name_)));
}
XlaDevice::~XlaDevice() {}

xla::LocalClient* XlaDevice::client() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  // For now GetOrCreateLocalClient always returns success when passed
  // a non-null platform. If that changes we may have to plumb in some
  // way to pass Status back.
  return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  } else {
    return xla_allocator_;
  }
}

Status XlaDevice::FillContextMap(const Graph* graph,
                                 DeviceContextMap* device_context_map) {
  VLOG(1) << "XlaDevice::FillContextMap";
  device_context_map->resize(graph->num_node_ids());
  XlaDeviceContext* ctx = new XlaDeviceContext(client());
  for (Node* n : graph->nodes()) {
    VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
    ctx->Ref();
    (*device_context_map)[n->id()] = ctx;
  }
  return Status::OK();
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
    Notification n;
    XlaTransferManager manager(client());
    manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                  [&n, &status](const Status& s) {
                                    status = s;
                                    n.Notify();
                                  });
    n.WaitForNotification();
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaDeviceDummyOp(context);
  };
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(jit_device)) {
    KernelDef* def = new KernelDef(*jit_def);
    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp",
                                              dummy_factory));
  }
  return registrations;
}

}  // namespace tensorflow
120
tensorflow/compiler/jit/xla_device.h
Normal file
@@ -0,0 +1,120 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles;
// state is managed by XLA.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or
// GPU), under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Wrapper class to store metadata about the XlaDevice in the
  // resource manager, where it can be looked up e.g., when lazily
  // creating the XlaCompilationCache device.
  class Metadata : public ResourceBase {
   public:
    Metadata(int device_ordinal, perftools::gputools::Platform* platform,
             const DeviceType& device_type);
    ~Metadata() override;

    // The index of the device on this host.
    int device_ordinal() const;

    perftools::gputools::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;

    string DebugString() override;

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    perftools::gputools::Platform* platform_;  // Not owned.
  };

  // Factory function. 'platform_name' is the name of the XLA platform.
  // 'device_name' is the name of the Tensorflow device to create.
  // 'jit_device_name' is the name of the corresponding JIT device.
  static Status Create(const string& platform_name, const string& device_name,
                       int device_ordinal, const string& jit_device_name,
                       const SessionOptions& options,
                       const string& name_prefix,
                       std::unique_ptr<XlaDevice>* device);

  XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
            int device_ordinal, const DeviceType& jit_device_name,
            ::perftools::gputools::Platform* platform,
            Allocator* xla_allocator);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override;
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override { return Status::OK(); }

  Status FillContextMap(const Graph* graph,
                        DeviceContextMap* device_context_map) override;

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override;

  xla::LocalClient* client() const;

 private:
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType& jit_device_name_;
  Allocator* xla_allocator_;                   // Not owned.
  ::perftools::gputools::Platform* platform_;  // Not owned.
};

// Builds dummy OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
181
tensorflow/compiler/jit/xla_device_context.cc
Normal file
@@ -0,0 +1,181 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

namespace tensorflow {

// The contents of tensors allocated by XlaDeviceAllocator.
struct XlaGlobalData {
  mutable mutex mu;
  // May be nullptr if there is no xla::GlobalData backing this Tensor.
  std::shared_ptr<xla::GlobalData> data GUARDED_BY(mu);
};

// The allocator used for Tensors assigned to the XLA device. The allocator
// doesn't actually back Tensors with storage. Instead, each tensor contains
// a XlaGlobalData that wraps XLA-managed storage.
XlaDeviceAllocator::XlaDeviceAllocator() = default;
XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // Regardless of the size requested, always allocate a XlaGlobalData. Respect
  // the alignment request because there is alignment checking even for Tensors
  // whose data is never accessed.
  void* p = port::aligned_malloc(sizeof(XlaGlobalData), alignment);
  VLOG(2) << "Allocated XLA device tensor " << p;
  return new (p) XlaGlobalData();
}

void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  XlaGlobalData* global_data = reinterpret_cast<XlaGlobalData*>(ptr);
  VLOG(2) << "Deallocated XLA device tensor " << ptr;
  global_data->~XlaGlobalData();
  port::aligned_free(ptr);
}

void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }

// Don't run any constructors or destructors for complex objects,
// since there is no backing store for the tensor to run them
// on. strings are the only complex objects currently stored in
// Tensors. If others are added, this set of overrides must be
// extended to include them.
void XlaDeviceAllocator::RunStringCtor(string* p, size_t n) {}
void XlaDeviceAllocator::RunStringDtor(string* p, size_t n) {}
void XlaDeviceAllocator::RunResourceCtor(ResourceHandle* p, size_t n) {}
void XlaDeviceAllocator::RunResourceDtor(ResourceHandle* p, size_t n) {}

static const XlaGlobalData* CastTensorToXlaGlobalData(const Tensor& tensor) {
  const XlaGlobalData* expression =
      reinterpret_cast<const XlaGlobalData*>(tensor.tensor_data().data());
  return expression;
}

static XlaGlobalData* CastTensorToXlaGlobalData(Tensor* tensor) {
  const XlaGlobalData* expression =
      reinterpret_cast<const XlaGlobalData*>(tensor->tensor_data().data());
  return const_cast<XlaGlobalData*>(expression);
}

XlaTransferManager::XlaTransferManager(xla::Client* client) : client_(client) {}

void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                               Device* device,
                                               Tensor* device_tensor,
                                               StatusCallback done) const {
  if (cpu_tensor->NumElements() > 0) {
    VLOG(2) << "CopyCPUTensorToDevice "
            << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
            << " "
            << reinterpret_cast<const void*>(
                   device_tensor->tensor_data().data())
            << cpu_tensor->NumElements();
    xla::Literal literal;
    Status status = HostTensorToLiteral(*cpu_tensor, &literal);
    if (!status.ok()) {
      done(status);
      return;
    }
    auto gd = client_->TransferToServer(literal);
    if (!gd.ok()) {
      done(gd.status());
      return;
    }
    SetTensorGlobalData(
        std::shared_ptr<xla::GlobalData>(std::move(gd.ValueOrDie())),
        device_tensor);
  } else {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
  }
  done(Status::OK());
}

void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                               StringPiece tensor_name,
                                               Device* device,
                                               Tensor* cpu_tensor,
                                               StatusCallback done) {
  if (device_tensor->NumElements() > 0) {
    VLOG(2) << "CopyDeviceTensorToCPU"
            << reinterpret_cast<const void*>(
                   device_tensor->tensor_data().data())
            << " "
            << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
            << device_tensor->NumElements();
    std::shared_ptr<xla::GlobalData> global_data =
        GetTensorGlobalData(*device_tensor);

    xla::Shape shape;
    Status status =
        TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape);
    if (!status.ok()) {
      done(status);
      return;
    }
    auto result = client_->Transfer(*global_data, &shape);
    if (!result.ok()) {
      done(result.status());
      return;
    }
    const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie());
    void* dst_ptr = DMAHelper::base(cpu_tensor);
    size_t total_bytes = cpu_tensor->TotalBytes();
    memcpy(dst_ptr, src_ptr, total_bytes);
  } else {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
  }
  done(Status::OK());
}

std::shared_ptr<xla::GlobalData> XlaTransferManager::GetTensorGlobalData(
    const Tensor& tensor) {
  const XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
  mutex_lock lock(data->mu);
  CHECK(data->data);
  return data->data;
}

void XlaTransferManager::SetTensorGlobalData(
    std::shared_ptr<xla::GlobalData> global_data, Tensor* tensor) {
  XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
  mutex_lock lock(data->mu);
  data->data = std::move(global_data);
}

XlaDeviceContext::XlaDeviceContext(xla::Client* client) : manager_(client) {}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done) const {
  manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done);
}

void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             StringPiece tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device,
                                 cpu_tensor, done);
}

}  // namespace tensorflow
|
92
tensorflow/compiler/jit/xla_device_context.h
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The allocator used for Tensors assigned to the XLA device. The allocator
|
||||
// doesn't actually back Tensors with storage. Instead, each tensor is a thin
|
||||
// wrapper around XLA-managed storage.
|
||||
class XlaDeviceAllocator : public Allocator {
|
||||
public:
|
||||
XlaDeviceAllocator();
|
||||
~XlaDeviceAllocator() override;
|
||||
|
||||
string Name() override;
|
||||
|
||||
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
|
||||
void DeallocateRaw(void* ptr) override;
|
||||
void GetStats(AllocatorStats* stats) override;
|
||||
|
||||
private:
|
||||
void RunStringCtor(string* p, size_t n) override;
|
||||
void RunStringDtor(string* p, size_t n) override;
|
||||
void RunResourceCtor(ResourceHandle* p, size_t n) override;
|
||||
void RunResourceDtor(ResourceHandle* p, size_t n) override;
|
||||
};
|
||||
|
||||
// Helper class for managing data transfers between host and XLA devices.
|
||||
class XlaTransferManager {
|
||||
public:
|
||||
explicit XlaTransferManager(xla::Client* client);
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor, StatusCallback done) const;
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done);
|
||||
|
||||
// Helper methods to get/set the xla::GlobalData backing a Tensor on the
|
||||
// XlaDevice.
|
||||
static std::shared_ptr<xla::GlobalData> GetTensorGlobalData(
|
||||
const Tensor& tensor);
|
||||
static void SetTensorGlobalData(std::shared_ptr<xla::GlobalData> global_data,
|
||||
Tensor* tensor);
|
||||
|
||||
private:
|
||||
xla::Client* client_;
|
||||
};
|
||||
|
||||
// DeviceContext for operators assigned to XlaDevice devices. The
|
||||
// implementation must inherit from DeviceContext but otherwise just
|
||||
// wraps the methods in XlaTransferManager.
|
||||
class XlaDeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit XlaDeviceContext(xla::Client* client);
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const override;
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
|
||||
private:
|
||||
XlaTransferManager manager_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
|
171
tensorflow/compiler/jit/xla_device_launch_op.cc
Normal file
@ -0,0 +1,171 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device_launch_op.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_device_context.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
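// Builds an XlaCompilationCache for this device from the XlaDevice::Metadata
// ("xla_metadata") stored in the ResourceMgr.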
Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) {
|
||||
XlaDevice::Metadata* metadata;
|
||||
Status s = rm->Lookup<XlaDevice::Metadata>(rm->default_container(),
|
||||
"xla_metadata", &metadata);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
core::ScopedUnref metadata_ref(metadata);
|
||||
XlaCompiler::Options options;
|
||||
options.device_type = metadata->jit_device_type();
|
||||
options.client = metadata->client();
|
||||
options.allow_cpu_custom_calls = false;
|
||||
options.local_executable_has_hybrid_result = false;
|
||||
*compiler = new XlaCompilationCache(options);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
const NameAttrList* func;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
|
||||
function_ = *func;
|
||||
VLOG(1) << "XlaDeviceLaunch created function="
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
DataTypeVector constant_types;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
|
||||
num_constant_args_ = constant_types.size();
|
||||
}
|
||||
|
||||
void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "XlaDeviceLaunch::Compute "
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
|
||||
|
||||
XlaCompilationCache* compiler;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
rm->LookupOrCreate<XlaCompilationCache>(
|
||||
rm->default_container(), "xla_compiler", &compiler,
|
||||
[rm](XlaCompilationCache** compiler) {
|
||||
return BuildCompilationCache(rm, compiler);
|
||||
}));
|
||||
// Hold the reference to the JIT during evaluation. (We could probably
|
||||
// free it sooner because the ResourceMgr will retain a reference, but
|
||||
// this is more obviously correct.)
|
||||
core::ScopedUnref compiler_ref(compiler);
|
||||
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
compiler->Compile(function_, num_constant_args_, ctx, &kernel, nullptr));
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(),
|
||||
errors::Internal("Unexpected number of outputs"));
|
||||
|
||||
// Run the computation, if any. There might not be a computation if all
|
||||
// outputs were compile-time constants.
|
||||
std::vector<std::unique_ptr<xla::GlobalData>> outputs;
|
||||
if (!kernel->computation.IsNull()) {
|
||||
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
|
||||
|
||||
// Convert argument tensors to xla::GlobalData pointers.
|
||||
std::vector<std::shared_ptr<xla::GlobalData>> arg_handles(
|
||||
kernel->xla_input_shapes.size());
|
||||
std::vector<xla::GlobalData*> arg_ptrs(kernel->xla_input_shapes.size());
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int input_num = kernel->xla_input_shapes[i].first;
|
||||
arg_handles[i] =
|
||||
XlaTransferManager::GetTensorGlobalData(ctx->input(input_num));
|
||||
arg_ptrs[i] = arg_handles[i].get();
|
||||
}
|
||||
|
||||
// Execute the computation.
|
||||
xla::ExecutionProfile profile;
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
auto result = compiler->client()->Execute(
|
||||
kernel->computation, arg_ptrs, &kernel->xla_output_shape, &profile);
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
OP_REQUIRES(ctx, result.ok(), result.status());
|
||||
|
||||
VLOG(1) << "Elapsed time: " << elapsed << "us";
|
||||
VLOG(1) << "ExecutionProfile: " << profile.DebugString();
|
||||
|
||||
if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) {
|
||||
auto outputs_or_error =
|
||||
compiler->client()->DeconstructTuple(*result.ValueOrDie());
|
||||
OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status());
|
||||
outputs = outputs_or_error.ConsumeValueOrDie();
|
||||
} else {
|
||||
outputs.push_back(result.ConsumeValueOrDie());
|
||||
}
|
||||
}
|
||||
|
||||
XlaDeviceContext* device_context = ctx->op_device_context<XlaDeviceContext>();
|
||||
|
||||
// Copy XLA outputs to the operator's outputs.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(i, kernel->outputs[i].shape, &output));
|
||||
if (kernel->outputs[i].is_constant) {
|
||||
// TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and
|
||||
// remove the copy from this code.
|
||||
Status status;
|
||||
device_context->CopyCPUTensorToDevice(
|
||||
&kernel->outputs[i].constant_value, nullptr, output,
|
||||
[&status](const Status& s) { status = s; });
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
CHECK_LT(output_num, outputs.size());
|
||||
XlaTransferManager::SetTensorGlobalData(
|
||||
std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])),
|
||||
output);
|
||||
++output_num;
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
|
||||
XlaDeviceLaunchOp::~XlaDeviceLaunchOp() {
|
||||
VLOG(1) << "XlaDeviceLaunch destroyed";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
50
tensorflow/compiler/jit/xla_device_launch_op.h
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph
|
||||
// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is
|
||||
// responsible for handling interactions with the TensorFlow executor.
|
||||
// Once all inputs are present, and their shapes are known, the op can
|
||||
// use a 'XlaCompilationCache' to compile and execute code which is specific
|
||||
// to the shapes of input Tensors.
|
||||
class XlaDeviceLaunchOp : public OpKernel {
|
||||
public:
|
||||
explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx);
|
||||
~XlaDeviceLaunchOp() override;
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
NameAttrList function_;
|
||||
int num_constant_args_;
|
||||
Tensor dummy_tensor_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
|
36
tensorflow/compiler/jit/xla_device_ops.cc
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
|
||||
const Tensor& rhs) {
|
||||
std::shared_ptr<xla::GlobalData> gd =
|
||||
XlaTransferManager::GetTensorGlobalData(rhs);
|
||||
XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
|
||||
}
|
||||
|
||||
XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
|
||||
LOG(FATAL) << "Attempted to execute Op " << name() << "type " << type_string()
|
||||
<< " on an XLA device. This should never happen.";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
118
tensorflow/compiler/jit/xla_device_ops.h
Normal file
@ -0,0 +1,118 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Common kernel registrations for XLA devices.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device_launch_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/kernels/assign_op.h"
|
||||
#include "tensorflow/core/kernels/constant_op.h"
|
||||
#include "tensorflow/core/kernels/control_flow_ops.h"
|
||||
#include "tensorflow/core/kernels/identity_op.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
#include "tensorflow/core/kernels/sendrecv_ops.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Implementation of Assign for XLA devices.
|
||||
class XlaDeviceAssignOp : public AssignOp {
|
||||
public:
|
||||
using AssignOp::AssignOp;
|
||||
|
||||
void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
|
||||
};
|
||||
|
||||
// Dummy OpKernel, used for kernels assigned to an XLA device that should be
|
||||
// compiled. Should never be called at runtime since such ops should be
|
||||
// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
|
||||
// operator on an XLA device but the compiler did not compile it.
|
||||
class XlaDeviceDummyOp : public OpKernel {
|
||||
public:
|
||||
explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
};
|
||||
|
||||
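// Registers KERNEL as the implementation of _XlaLaunch on DEVICE, keeping the
// compile-time constant arguments in host memory.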
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_XlaLaunch").Device(DEVICE).HostMemory("constants"), KERNEL);
|
||||
|
||||
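// Registers the non-compiled kernels (Send/Recv, Const, Identity, variable and
// control-flow ops, etc.) that every XLA device needs.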
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||
ConstantOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), \
|
||||
XlaDeviceDummyOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||
VariableOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||
VariableOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||
TemporaryVariableOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint("T", TYPES), \
|
||||
DestroyTemporaryVariableOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint("dtype", TYPES) \
|
||||
.HostMemory("is_initialized"), \
|
||||
IsVariableInitializedOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \
|
||||
XlaDeviceAssignOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
|
||||
ControlTriggerOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
|
||||
NextIterationOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
|
||||
SwitchOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("LoopCond") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("output"), \
|
||||
IdentityOp);
|
||||
|
||||
// TODO(phawkins): do we really need Placeholder? Should it be a real
|
||||
// implementation of Placeholder?
|
||||
|
||||
// TODO(b/32507444): the registrations for the control flow operators are
|
||||
// temporary and exist primarily to work around a bug in the graph partitioning
|
||||
// code.
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
|
65
tensorflow/compiler/jit/xla_gpu_device.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const DEVICE_XLA_GPU = "XLA_GPU";
|
||||
|
||||
class XlaGpuDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override;
|
||||
};
|
||||
|
||||
Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
||||
(void)registrations;
|
||||
|
||||
std::unique_ptr<XlaDevice> device;
|
||||
Status status =
|
||||
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
|
||||
name_prefix, &device);
|
||||
if (!status.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
LOG(WARNING) << "Failed to create XLA_GPU device: " << status;
|
||||
return Status::OK();
|
||||
}
|
||||
devices->push_back(device.release());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaDeviceLaunchOp, kAllXlaGpuTypes);
|
||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
||||
|
||||
} // namespace tensorflow
|
342
tensorflow/compiler/jit/xla_local_launch_op.cc
Normal file
@ -0,0 +1,342 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
|
||||
namespace gpu = perftools::gputools;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("_XlaLaunch")
|
||||
.Input("constants: Tconstants")
|
||||
.Attr("Tconstants: list(type) >= 0")
|
||||
.Input("args: Targs")
|
||||
.Attr("Targs: list(type) >= 0")
|
||||
.Output("results: Tresults")
|
||||
.Attr("Tresults: list(type) >= 0")
|
||||
.Attr("function: func")
|
||||
.Doc("XLA Launch Op. For use by the XLA JIT only.");
|
||||
|
||||
// Adapter class that wraps a TensorFlow allocator as an XLA allocator.
|
||||
class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
public:
|
||||
XlaAllocator(const perftools::gputools::Platform* platform,
|
||||
OpKernelContext* op_context);
|
||||
~XlaAllocator() override;
|
||||
xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure = true) override;
|
||||
Status Deallocate(int device_ordinal,
|
||||
perftools::gputools::DeviceMemoryBase* mem) override;
|
||||
|
||||
// Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
|
||||
// interpreted as having data type 'dtype' and shape 'shape'.
|
||||
Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
|
||||
const TensorShape& shape, Tensor* tensor) const;
|
||||
|
||||
private:
|
||||
OpKernelContext* const op_context_;
|
||||
|
||||
// Map from pointer address to the owning Tensor; used by
|
||||
// MakeTensorFromBuffer. Also used to automatically release Tensors when the
|
||||
// allocator is freed.
|
||||
std::unordered_map<void*, Tensor> tensors_;
|
||||
};
|
||||
|
||||
XlaAllocator::XlaAllocator(const perftools::gputools::Platform* platform,
|
||||
OpKernelContext* op_context)
|
||||
: xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
|
||||
|
||||
XlaAllocator::~XlaAllocator() = default;
|
||||
|
||||
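// Satisfies an allocation request by allocating a DT_UINT8 temporary Tensor of
// the requested size and recording it in tensors_, keyed by buffer address, so
// MakeTensorFromBuffer can later recover the owning Tensor.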
xla::StatusOr<perftools::gputools::DeviceMemoryBase> XlaAllocator::Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) {
|
||||
AllocatorAttributes allocator_attrs;
|
||||
allocator_attrs.set_on_host(false);
|
||||
|
||||
AllocationAttributes allocation_attrs;
|
||||
allocation_attrs.no_retry_on_failure = !retry_on_failure;
|
||||
|
||||
Tensor t;
|
||||
Status status = op_context_->allocate_temp(
|
||||
DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
|
||||
allocation_attrs);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Allocation failed " << size;
|
||||
return status;
|
||||
}
|
||||
void* data =
|
||||
reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
|
||||
TF_RET_CHECK(data != nullptr);
|
||||
tensors_[data] = t;
|
||||
return perftools::gputools::DeviceMemoryBase(data, size);
|
||||
}
|
||||
|
||||
Status XlaAllocator::Deallocate(int device_ordinal,
|
||||
perftools::gputools::DeviceMemoryBase* mem) {
|
||||
if (mem->opaque() != nullptr) {
|
||||
if (tensors_.erase(mem->opaque()) == 0) {
|
||||
return tensorflow::errors::InvalidArgument("Unknown tensor address");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
|
||||
DataType dtype,
|
||||
const TensorShape& shape,
|
||||
Tensor* out_tensor) const {
|
||||
void* ptr = const_cast<void*>(buffer.opaque());
|
||||
auto it = tensors_.find(ptr);
|
||||
if (it == tensors_.end()) {
|
||||
return errors::InvalidArgument("Unknown tensor address");
|
||||
}
|
||||
const Tensor& tensor = it->second;
|
||||
|
||||
int64 output_size = DataTypeSize(dtype) * shape.num_elements();
|
||||
if (tensor.TotalBytes() == output_size) {
|
||||
out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
|
||||
} else {
|
||||
Tensor slice = tensor.Slice(0, output_size);
|
||||
out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), device_type_(ctx->device_type()) {
|
||||
const NameAttrList* func;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
|
||||
function_ = *func;
|
||||
DataTypeVector constant_types;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
|
||||
num_constant_args_ = constant_types.size();
|
||||
}
|
||||
|
||||
Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) {
|
||||
gpu::Platform::Id platform_id;
|
||||
if (device_type_ == DeviceType(DEVICE_CPU)) {
|
||||
platform_id = gpu::host::kHostPlatformId;
|
||||
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
|
||||
platform_id = gpu::cuda::kCudaPlatformId;
|
||||
} else {
|
||||
return errors::InvalidArgument("Unknown device type for local _XlaLaunch");
|
||||
}
|
||||
|
||||
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
if (!platform.ok()) {
|
||||
return StreamExecutorUtil::ConvertStatus(platform.status());
|
||||
}
|
||||
auto client =
|
||||
xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
|
||||
if (!client.ok()) {
|
||||
return client.status();
|
||||
}
|
||||
const string* compiler_device;
|
||||
if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device,
|
||||
/*requires_jit=*/nullptr)) {
|
||||
return errors::InvalidArgument("No JIT device registered for ",
|
||||
device_type_.type());
|
||||
}
|
||||
XlaCompiler::Options options;
|
||||
options.device_type = DeviceType(*compiler_device);
|
||||
options.client = client.ValueOrDie();
|
||||
options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId);
|
||||
options.local_executable_has_hybrid_result = true;
|
||||
*compiler = new XlaCompilationCache(options);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "XlaLocalLaunchOp::Compute "
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
|
||||
|
||||
gpu::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
XlaCompilationCache* compiler;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
rm->LookupOrCreate<XlaCompilationCache>(
|
||||
rm->default_container(), "xla_compiler", &compiler,
|
||||
[this](XlaCompilationCache** compiler) {
|
||||
return BuildCompilationCache(compiler);
|
||||
}));
|
||||
// Hold the reference to the JIT during evaluation. (We could probably
|
||||
// free it sooner because the ResourceMgr will retain a reference, but
|
||||
// this is more obviously correct.)
|
||||
core::ScopedUnref compiler_ref(compiler);
|
||||
|
||||
xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client());
|
||||
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
compiler->Compile(function_, num_constant_args_, ctx, &kernel,
|
||||
&executable));
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
// Builds an XLA allocator for the device.
|
||||
XlaAllocator xla_allocator(client->platform(), ctx);
|
||||
XlaLocalRuntimeContext local_runtime_context;
|
||||
|
||||
std::unique_ptr<xla::ShapedBuffer> output;
|
||||
bool output_is_tuple;
|
||||
if (!kernel->computation.IsNull()) {
|
||||
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
|
||||
arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
|
||||
arg_buffers.resize(kernel->xla_input_shapes.size());
|
||||
std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = kernel->xla_input_shapes[i].first;
|
||||
const xla::Shape& shape = kernel->xla_input_shapes[i].second;
|
||||
gpu::DeviceMemoryBase dmem(
|
||||
const_cast<char*>(ctx->input(arg_num).tensor_data().data()),
|
||||
ctx->input(arg_num).tensor_data().size());
|
||||
|
||||
arg_buffers[i] =
|
||||
xla::ShapedBuffer::MakeArrayShapedBuffer(
|
||||
shape, client->platform(), client->default_device_ordinal(), dmem)
|
||||
.ConsumeValueOrDie();
|
||||
arg_ptrs[i] = arg_buffers[i].get();
|
||||
}
|
||||
|
||||
// Make the final parameter point at local_runtime_context.
|
||||
if (kernel->requires_runtime_context) {
|
||||
gpu::DeviceMemoryBase local_runtime_context_dmem(
|
||||
&local_runtime_context, sizeof(local_runtime_context));
|
||||
arg_buffers.push_back(
|
||||
xla::ShapedBuffer::MakeArrayShapedBuffer(
|
||||
xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
|
||||
client->default_device_ordinal(), local_runtime_context_dmem)
|
||||
.ConsumeValueOrDie());
|
||||
arg_ptrs.push_back(arg_buffers.back().get());
|
||||
}
|
||||
|
||||
// Execute the computation.
|
||||
VLOG(2) << "Executing computation.";
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(&xla_allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
auto run_result = executable->Run(arg_ptrs, run_options);
|
||||
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
|
||||
|
||||
if (local_runtime_context.error) {
|
||||
ctx->CtxFailure(errors::InvalidArgument(
|
||||
"Compiled kernel returned error: ", local_runtime_context.error_msg));
|
||||
return;
|
||||
}
|
||||
|
||||
output = std::move(run_result.ValueOrDie());
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time: " << elapsed << "us";
|
||||
|
||||
// Computation output should always be a tuple.
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
|
||||
}
|
||||
output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
|
||||
}
|
||||
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
|
||||
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
if (kernel->outputs[i].is_constant) {
|
||||
// Output is a constant
|
||||
const Tensor& const_tensor = kernel->outputs[i].constant_value;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
Tensor* output_tensor;
|
||||
TF_CHECK_OK(
|
||||
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
|
||||
|
||||
const void* src_ptr = DMAHelper::base(&const_tensor);
|
||||
void* dst_ptr = DMAHelper::base(output_tensor);
|
||||
gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
|
||||
stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(i, const_tensor);
|
||||
}
|
||||
} else {
|
||||
const TensorShape& shape = kernel->outputs[i].shape;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
|
||||
|
||||
gpu::DeviceMemoryBase buffer;
|
||||
if (output_is_tuple) {
|
||||
buffer = output->buffer({output_num});
|
||||
} else {
|
||||
CHECK_EQ(0, output_num);
|
||||
buffer = output->buffer({});
|
||||
}
|
||||
Tensor output_tensor;
|
||||
// Looks up the owning Tensor by buffer address.
|
||||
OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
|
||||
buffer, ctx->expected_output_dtype(i), shape,
|
||||
&output_tensor));
|
||||
ctx->set_output(i, output_tensor);
|
||||
++output_num;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << ctx->mutable_output(i)->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
|
||||
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
|
||||
VLOG(1) << "XlaLocalLaunchOp destroyed";
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
|
||||
XlaLocalLaunchOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_XlaLaunch").Device(DEVICE_GPU).HostMemory("constants"),
|
||||
XlaLocalLaunchOp);
|
||||
|
||||
} // namespace tensorflow
|
55
tensorflow/compiler/jit/xla_local_launch_op.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
||||
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
|
||||
// responsible for handling interactions with the TensorFlow executor.
|
||||
// Once all inputs are present, and their shapes are known, the op can
|
||||
// use a 'XlaCompilationCache' to compile and execute code which is specific
|
||||
// to the shapes of input Tensors.
|
||||
// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes
|
||||
// arguments into/out of XLA in device memory.
|
||||
class XlaLocalLaunchOp : public OpKernel {
|
||||
public:
|
||||
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
|
||||
~XlaLocalLaunchOp() override;
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
// Builds a XlaCompilationCache class suitable for the current device.
|
||||
Status BuildCompilationCache(XlaCompilationCache** compiler);
|
||||
|
||||
DeviceType device_type_;
|
||||
NameAttrList function_;
|
||||
int num_constant_args_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
|
352
tensorflow/compiler/tests/BUILD
Normal file
@ -0,0 +1,352 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
|
||||
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
|
||||
|
||||
generate_backend_suites()
|
||||
|
||||
py_library(
|
||||
name = "xla_test",
|
||||
testonly = 1,
|
||||
srcs = ["xla_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/compiler:compiler_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "depthwise_conv2d_test_kernel",
|
||||
testonly = 1,
|
||||
srcs = ["depthwise_conv2d_test_kernel.cc"],
|
||||
deps = ["//tensorflow/core:framework_lite"],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "binary_ops_test",
|
||||
size = "small",
|
||||
srcs = ["binary_ops_test.py"],
|
||||
shard_count = 5,
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "clustering_test",
|
||||
size = "small",
|
||||
srcs = ["clustering_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "concat_ops_test",
|
||||
size = "small",
|
||||
srcs = ["concat_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:array_ops_gen",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradient_checker",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "conv2d_test",
|
||||
size = "medium",
|
||||
srcs = ["conv2d_test.py"],
|
||||
shard_count = 10,
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "dynamic_stitch_test",
|
||||
size = "small",
|
||||
srcs = ["dynamic_stitch_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "function_test",
|
||||
size = "small",
|
||||
srcs = ["function_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "lrn_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["lrn_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "nary_ops_test",
|
||||
size = "small",
|
||||
srcs = ["nary_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "nullary_ops_test",
|
||||
size = "small",
|
||||
srcs = ["nullary_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "pooling_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["pooling_ops_test.py"],
|
||||
shard_count = 10,
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "reduce_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["reduce_ops_test.py"],
|
||||
shard_count = 5,
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "ternary_ops_test",
|
||||
size = "small",
|
||||
srcs = ["ternary_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "unary_ops_test",
|
||||
size = "small",
|
||||
srcs = ["unary_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "xla_device_test",
|
||||
size = "small",
|
||||
srcs = ["xla_device_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "jit_test",
|
||||
size = "medium",
|
||||
srcs = ["jit_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/contrib/compiler:compiler_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "randomized_tests_library",
|
||||
testonly = 1,
|
||||
srcs = ["randomized_tests.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "randomized_tests",
|
||||
# This test is randomized, so only run it if explicitly requested.
|
||||
tags = [
|
||||
"manual",
|
||||
"noguitar",
|
||||
"notap",
|
||||
],
|
||||
deps = [":randomized_tests_library"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "lstm",
|
||||
testonly = 1,
|
||||
srcs = ["lstm.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "lstm_test",
|
||||
srcs = ["lstm_test.py"],
|
||||
additional_deps = [
|
||||
":lstm",
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
# An example of ahead-of-time compilation using tfcompile. The
|
||||
# lstm_layer_inference.pbtxt file was generated by running lstm_test
|
||||
# --dump_graph_dir, and the config file was written by hand.
|
||||
#
|
||||
# Run the following to build a minimal benchmark of the computation on Android:
|
||||
# $ bazel build -c opt --config=android_arm \
|
||||
# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark
|
||||
#
|
||||
# Currently the resulting binary size is ~190KB
|
||||
tf_library(
|
||||
name = "lstm_layer_inference",
|
||||
testonly = 1,
|
||||
config = "lstm_layer_inference.config.pbtxt",
|
||||
cpp_class = "LSTMLayerInference",
|
||||
graph = "lstm_layer_inference.pbtxt",
|
||||
tags = ["manual"],
|
||||
tfcompile_flags = "--xla_cpu_multi_thread_eigen=false",
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
749
tensorflow/compiler/tests/binary_ops_test.py
Normal file
@ -0,0 +1,749 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for binary operators."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class BinaryOpsTest(XLATestCase):
|
||||
"""Test cases for binary operators."""
|
||||
|
||||
def _testBinary(self, op, a, b, expected, equality_test=None):
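    """Runs `op` on placeholders fed with `a` and `b` inside the XLA test scope
    and checks the result against `expected` (assertAllClose by default)."""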
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
|
||||
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
|
||||
output = op(pa, pb)
|
||||
result = session.run(output, {pa: a, pb: b})
|
||||
if equality_test is None:
|
||||
equality_test = self.assertAllClose
|
||||
equality_test(result, expected, rtol=1e-3)
|
||||
|
||||
def ListsAreClose(self, result, expected, rtol):
|
||||
"""Tests closeness of two lists of floats."""
|
||||
self.assertEqual(len(result), len(expected))
|
||||
for i in range(len(result)):
|
||||
self.assertAllClose(result[i], expected[i], rtol)
|
||||
|
||||
def testFloatOps(self):
|
||||
for dtype in self.float_types:
|
||||
self._testBinary(
|
||||
gen_math_ops._real_div,
|
||||
np.array([3, 3, -1.5, -8, 44], dtype=dtype),
|
||||
np.array([2, -2, 7, -4, 0], dtype=dtype),
|
||||
expected=np.array(
|
||||
[1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype))
|
||||
|
||||
self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.pow,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.zeros(shape=[0, 2], dtype=dtype),
|
||||
expected=np.zeros(shape=[0, 2], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.pow,
|
||||
np.array([10, 4], dtype=dtype),
|
||||
np.array([2, 3], dtype=dtype),
|
||||
expected=np.array([100, 64], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.pow,
|
||||
dtype(2),
|
||||
np.array([3, 4], dtype=dtype),
|
||||
expected=np.array([8, 16], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.pow,
|
||||
np.array([[2], [3]], dtype=dtype),
|
||||
dtype(4),
|
||||
expected=np.array([[16], [81]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._sigmoid_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([-60, -36, -14, 0], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._rsqrt_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([-160, -81, -28, -4], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._softplus_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array(
|
||||
[3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._tanh_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([-75, -48, -21, 0], dtype=dtype))
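# TanhGrad(y, dy) = dy * (1 - y**2), e.g. 5 * (1 - 4**2) = -75 for the first
# element above.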
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._relu_grad,
|
||||
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
|
||||
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
|
||||
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype))
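# ReluGrad passes each incoming gradient (first argument) through only where
# the corresponding feature (second argument) is positive, zeroing the rest.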
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._relu6_grad,
|
||||
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype),
|
||||
np.array(
|
||||
[0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
|
||||
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
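# Relu6Grad passes gradients through only where 0 < feature < 6, so the
# entries with features 6.1 and 10.0 are zeroed as well.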
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._softmax_cross_entropy_with_logits,
|
||||
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),
|
||||
np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype),
|
||||
expected=[
|
||||
np.array([1.44019, 2.44019], dtype=dtype),
|
||||
np.array([[-0.067941, -0.112856, -0.063117, 0.243914],
|
||||
[-0.367941, -0.212856, 0.036883, 0.543914]],
|
||||
dtype=dtype),
|
||||
],
|
||||
equality_test=self.ListsAreClose)
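# The op returns a pair: the per-row loss and the gradient w.r.t. the logits,
# i.e. softmax(logits) - labels. For the first row, softmax([1, 2, 3, 4]) is
# approximately [0.0321, 0.0871, 0.2369, 0.6439], which gives the gradient row
# above and a loss of about 1.4402.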
|
||||
|
||||
def testIntOps(self):
|
||||
for dtype in self.int_types:
|
||||
self._testBinary(
|
||||
gen_math_ops._truncate_div,
|
||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, 2, -4], dtype=dtype),
|
||||
expected=np.array([1, -1, 0, -4, 2], dtype=dtype))
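# TruncateDiv rounds the quotient toward zero, e.g. -9 / 2 -> -4 (FloorDiv,
# tested in _testDivision below, gives -5 instead).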
|
||||
|
||||
def testNumericOps(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([11, 22], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
dtype(5),
|
||||
np.array([1, 2], dtype=dtype),
|
||||
expected=np.array([6, 7], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[8], [9]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.sub,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([-9, -18], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.sub,
|
||||
dtype(5),
|
||||
np.array([1, 2], dtype=dtype),
|
||||
expected=np.array([4, 3], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.sub,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[-6], [-5]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([10, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([5, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[10], [7]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([1, 20], dtype=dtype),
|
||||
np.array([10, 2], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([1, 5], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[7], [2]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.mul,
|
||||
np.array([1, 20], dtype=dtype),
|
||||
np.array([10, 2], dtype=dtype),
|
||||
expected=np.array([10, 40], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.mul,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([5, 100], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.mul,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[70], [14]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([81, 324], dtype=dtype))
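# squared_difference(x, y) computes (x - y)**2 elementwise, e.g.
# (1 - 10)**2 = 81 for the first element above.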
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
dtype(5),
|
||||
np.array([1, 2], dtype=dtype),
|
||||
expected=np.array([16, 9], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[36], [25]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
nn_ops.bias_add,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([2, -1], dtype=dtype),
|
||||
expected=np.array([[3, 1], [5, 3]], dtype=dtype))
|
||||
self._testBinary(
|
||||
nn_ops.bias_add,
|
||||
np.array([[[[1, 2], [3, 4]]]], dtype=dtype),
|
||||
np.array([2, -1], dtype=dtype),
|
||||
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
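# bias_add broadcasts the bias along the last (channel) dimension, so the same
# [2, -1] bias applies to both the 2-D and the 4-D input above.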
|
||||
|
||||
def _testDivision(self, dtype):
|
||||
"""Test cases for division operators."""
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
np.array([10, 20], dtype=dtype),
|
||||
np.array([10, 2], dtype=dtype),
|
||||
expected=np.array([1, 10], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
dtype(40),
|
||||
np.array([2, 20], dtype=dtype),
|
||||
expected=np.array([20, 2], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
np.array([[10], [4]], dtype=dtype),
|
||||
dtype(2),
|
||||
expected=np.array([[5], [2]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._floor_div,
|
||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, 2, -4], dtype=dtype),
|
||||
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
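# FloorDiv rounds the quotient toward negative infinity, e.g. 3 / -2 -> -2 and
# -9 / 2 -> -5, unlike TruncateDiv, which rounds toward zero.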
|
||||
|
||||
def testIntDivision(self):
|
||||
for dtype in self.int_types:
|
||||
self._testDivision(dtype)
|
||||
|
||||
def testFloatDivision(self):
|
||||
for dtype in self.float_types:
|
||||
self._testDivision(dtype)
|
||||
|
||||
def _testRemainder(self, dtype):
|
||||
"""Test cases for remainder operators."""
|
||||
self._testBinary(
|
||||
gen_math_ops._floor_mod,
|
||||
np.array([3, 3, -1, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, -4], dtype=dtype),
|
||||
expected=np.array([1, -1, 6, 0], dtype=dtype))
|
||||
self._testBinary(
|
||||
gen_math_ops._truncate_mod,
|
||||
np.array([3, 3, -1, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, -4], dtype=dtype),
|
||||
expected=np.array([1, 1, -1, 0], dtype=dtype))
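# FloorMod takes the sign of the divisor (x - floor(x / y) * y) while
# TruncateMod takes the sign of the dividend, e.g. 3 mod -2 is -1 under
# FloorMod but 1 under TruncateMod.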
|
||||
|
||||
def testIntRemainder(self):
|
||||
for dtype in self.int_types:
|
||||
self._testRemainder(dtype)
|
||||
|
||||
def testFloatRemainder(self):
|
||||
for dtype in self.float_types:
|
||||
self._testRemainder(dtype)
|
||||
|
||||
def testLogicalOps(self):
|
||||
self._testBinary(
|
||||
math_ops.logical_and,
|
||||
np.array([[True, False], [False, True]], dtype=np.bool),
|
||||
np.array([[False, True], [False, True]], dtype=np.bool),
|
||||
expected=np.array([[False, False], [False, True]], dtype=np.bool))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.logical_or,
|
||||
np.array([[True, False], [False, True]], dtype=np.bool),
|
||||
np.array([[False, True], [False, True]], dtype=np.bool),
|
||||
expected=np.array([[True, True], [False, True]], dtype=np.bool))
|
||||
|
||||
def testComparisons(self):
|
||||
self._testBinary(
|
||||
math_ops.equal,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([False, True, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
math_ops.equal,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([False, True, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
math_ops.equal,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[False], [True], [False]], dtype=np.bool))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.not_equal,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([True, False, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
math_ops.not_equal,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([True, False, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
math_ops.not_equal,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[True], [False], [True]], dtype=np.bool))
|
||||
|
||||
for greater_op in [math_ops.greater, (lambda x, y: x > y)]:
|
||||
self._testBinary(
|
||||
greater_op,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([False, False, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
greater_op,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([True, False, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
greater_op,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[True], [False], [False]], dtype=np.bool))
|
||||
|
||||
for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]:
|
||||
self._testBinary(
|
||||
greater_equal_op,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([False, True, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
greater_equal_op,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([True, True, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
greater_equal_op,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[True], [True], [False]], dtype=np.bool))
|
||||
|
||||
for less_op in [math_ops.less, (lambda x, y: x < y)]:
|
||||
self._testBinary(
|
||||
less_op,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([True, False, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
less_op,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([False, False, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
less_op,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[False], [False], [True]], dtype=np.bool))
|
||||
|
||||
for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
|
||||
self._testBinary(
|
||||
less_equal_op,
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
np.array([10, 5, 2], dtype=np.float32),
|
||||
expected=np.array([True, True, False], dtype=np.bool))
|
||||
self._testBinary(
|
||||
less_equal_op,
|
||||
np.float32(5),
|
||||
np.array([1, 5, 20], dtype=np.float32),
|
||||
expected=np.array([False, True, True], dtype=np.bool))
|
||||
self._testBinary(
|
||||
less_equal_op,
|
||||
np.array([[10], [7], [2]], dtype=np.float32),
|
||||
np.float32(7),
|
||||
expected=np.array([[False], [True], [True]], dtype=np.bool))
|
||||
|
||||
def testBroadcasting(self):
|
||||
"""Tests broadcasting behavior of an operator."""
|
||||
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array(3, dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([13, 23], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([10, 20], dtype=dtype),
|
||||
np.array(4, dtype=dtype),
|
||||
expected=np.array([14, 24], dtype=dtype))
|
||||
|
||||
# [1,3] x [4,1] => [4,3]
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([[10, 20, 30]], dtype=dtype),
|
||||
np.array([[1], [2], [3], [4]], dtype=dtype),
|
||||
expected=np.array(
|
||||
[[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
|
||||
dtype=dtype))
|
||||
|
||||
# [3] * [4,1] => [4,3]
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([10, 20, 30], dtype=dtype),
|
||||
np.array([[1], [2], [3], [4]], dtype=dtype),
|
||||
expected=np.array(
|
||||
[[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
|
||||
dtype=dtype))
|
||||
|
||||
def testFill(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.fill,
|
||||
np.array([], dtype=np.int32),
|
||||
dtype(-42),
|
||||
expected=dtype(-42))
|
||||
self._testBinary(
|
||||
array_ops.fill,
|
||||
np.array([1, 2], dtype=np.int32),
|
||||
dtype(7),
|
||||
expected=np.array([[7, 7]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.fill,
|
||||
np.array([3, 2], dtype=np.int32),
|
||||
dtype(50),
|
||||
expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype))
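# fill(dims, value) returns a tensor of shape `dims` filled with `value`; the
# empty dims vector in the first case above produces a scalar.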
|
||||
|
||||
# Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below.
|
||||
def _testMatMul(self, op):
|
||||
for dtype in self.float_types:
|
||||
self._testBinary(
|
||||
op,
|
||||
np.array([[-0.25]], dtype=dtype),
|
||||
np.array([[8]], dtype=dtype),
|
||||
expected=np.array([[-2]], dtype=dtype))
|
||||
self._testBinary(
|
||||
op,
|
||||
np.array([[100, 10, 0.5]], dtype=dtype),
|
||||
np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
|
||||
expected=np.array([[123, 354]], dtype=dtype))
|
||||
self._testBinary(
|
||||
op,
|
||||
np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
|
||||
np.array([[100], [10]], dtype=dtype),
|
||||
expected=np.array([[130], [250], [680]], dtype=dtype))
|
||||
self._testBinary(
|
||||
op,
|
||||
np.array([[1000, 100], [10, 1]], dtype=dtype),
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
expected=np.array([[1300, 2400], [13, 24]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
op,
|
||||
np.array([], dtype=dtype).reshape((2, 0)),
|
||||
np.array([], dtype=dtype).reshape((0, 3)),
|
||||
expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
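# A (2, 0) x (0, 3) product contracts over an empty dimension, so the result
# is a (2, 3) matrix of zeros.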
|
||||
|
||||
def testMatMul(self):
|
||||
self._testMatMul(math_ops.matmul)
|
||||
|
||||
# TODO(phawkins): failing on GPU, no registered kernel.
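# Note: this disabled test also refers to `tf.sparse_matmul`, but no `tf`
# module alias is imported in this file, so that import would need to be added
# before re-enabling it.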
|
||||
def DISABLED_testSparseMatMul(self):
|
||||
# Binary wrappers for sparse_matmul with different hints
|
||||
def SparseMatmulWrapperTF(a, b):
|
||||
return tf.sparse_matmul(a, b, a_is_sparse=True)
|
||||
|
||||
def SparseMatmulWrapperFT(a, b):
|
||||
return tf.sparse_matmul(a, b, b_is_sparse=True)
|
||||
|
||||
def SparseMatmulWrapperTT(a, b):
|
||||
return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
|
||||
|
||||
self._testMatMul(tf.sparse_matmul)
|
||||
self._testMatMul(SparseMatmulWrapperTF)
|
||||
self._testMatMul(SparseMatmulWrapperFT)
|
||||
self._testMatMul(SparseMatmulWrapperTT)
|
||||
|
||||
def testBatchMatMul(self):
|
||||
# Same tests as for tf.matmul above.
|
||||
self._testMatMul(math_ops.matmul)
|
||||
|
||||
# Tests with batches of matrices.
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[-0.25]]], dtype=np.float32),
|
||||
np.array([[[8]]], dtype=np.float32),
|
||||
expected=np.array([[[-2]]], dtype=np.float32))
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[-0.25]], [[4]]], dtype=np.float32),
|
||||
np.array([[[8]], [[2]]], dtype=np.float32),
|
||||
expected=np.array([[[-2]], [[8]]], dtype=np.float32))
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array(
|
||||
[[[[1000, 100], [10, 1]], [[2000, 200], [20, 2]]],
|
||||
[[[3000, 300], [30, 3]], [[4000, 400], [40, 4]]]],
|
||||
dtype=np.float32),
|
||||
np.array(
|
||||
[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]],
|
||||
[[55, 66], [77, 88]]]],
|
||||
dtype=np.float32),
|
||||
expected=np.array(
|
||||
[[[[1300, 2400], [13, 24]], [[11400, 13600], [114, 136]]],
|
||||
[[[42900, 79200], [429, 792]], [[250800, 299200], [2508, 2992]]]],
|
||||
dtype=np.float32))
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([], dtype=np.float32).reshape((2, 2, 0)),
|
||||
np.array([], dtype=np.float32).reshape((2, 0, 3)),
|
||||
expected=np.array(
|
||||
[[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]],
|
||||
dtype=np.float32))
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([], dtype=np.float32).reshape((0, 2, 4)),
|
||||
np.array([], dtype=np.float32).reshape((0, 4, 3)),
|
||||
expected=np.array([], dtype=np.float32).reshape(0, 2, 3))
|
||||
|
||||
# Regression test for b/31472796.
|
||||
if hasattr(np, "matmul"):
|
||||
x = np.arange(0, 3 * 5 * 16 * 7, dtype=np.float32).reshape((3, 5, 16, 7))
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
|
||||
x, x,
|
||||
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
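# With adjoint_b=True the last two dimensions of the right-hand operand are
# (conjugate-)transposed; for real inputs this matches np.matmul against an
# explicit transpose, as used for the expected value above.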
|
||||
|
||||
def testExpandDims(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
dtype(7),
|
||||
np.int32(0),
|
||||
expected=np.array([7], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([42], dtype=dtype),
|
||||
np.int32(0),
|
||||
expected=np.array([[42]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([], dtype=dtype),
|
||||
np.int32(0),
|
||||
expected=np.array([[]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.int32(0),
|
||||
expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.int32(1),
|
||||
expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.int32(2),
|
||||
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.int32(3),
|
||||
expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
|
||||
|
||||
def testPad(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.pad,
|
||||
np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=dtype),
|
||||
np.array(
|
||||
[[1, 2], [2, 1]], dtype=np.int32),
|
||||
expected=np.array(
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 2, 3, 0],
|
||||
[0, 0, 4, 5, 6, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
dtype=dtype))
|
||||
|
||||
def testReshape(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([], dtype=dtype),
|
||||
np.array([0, 4], dtype=np.int32),
|
||||
expected=np.zeros(shape=[0, 4], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([2, 3], dtype=np.int32),
|
||||
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([3, 2], dtype=np.int32),
|
||||
expected=np.array([[0, 1], [2, 3], [4, 5]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([-1, 6], dtype=np.int32),
|
||||
expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([6, -1], dtype=np.int32),
|
||||
expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([2, -1], dtype=np.int32),
|
||||
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.reshape,
|
||||
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
|
||||
np.array([-1, 3], dtype=np.int32),
|
||||
expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
|
||||
|
||||
def testSplit(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
|
||||
np.int32(0),
|
||||
np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
|
||||
dtype=dtype),
|
||||
expected=[
|
||||
np.array([[[1], [2]]], dtype=dtype),
|
||||
np.array([[[3], [4]]], dtype=dtype),
|
||||
np.array([[[5], [6]]], dtype=dtype),
|
||||
],
|
||||
equality_test=self.ListsAreClose)
|
||||
|
||||
self._testBinary(
|
||||
lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
|
||||
np.int32(1),
|
||||
np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
|
||||
dtype=dtype),
|
||||
expected=[
|
||||
np.array([[[1]], [[3]], [[5]]], dtype=dtype),
|
||||
np.array([[[2]], [[4]], [[6]]], dtype=dtype),
|
||||
],
|
||||
equality_test=self.ListsAreClose)
|
||||
|
||||
def testTile(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.tile,
|
||||
np.array([[6]], dtype=dtype),
|
||||
np.array([1, 2], dtype=np.int32),
|
||||
expected=np.array([[6, 6]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.tile,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
np.array([1, 2], dtype=np.int32),
|
||||
expected=np.array([[1, 1], [2, 2]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.tile,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([3, 2], dtype=np.int32),
|
||||
expected=np.array(
|
||||
[[1, 2, 1, 2],
|
||||
[3, 4, 3, 4],
|
||||
[1, 2, 1, 2],
|
||||
[3, 4, 3, 4],
|
||||
[1, 2, 1, 2],
|
||||
[3, 4, 3, 4]],
|
||||
dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.tile,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([1, 1], dtype=np.int32),
|
||||
expected=np.array(
|
||||
[[1, 2],
|
||||
[3, 4]],
|
||||
dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.tile,
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
np.array([3, 1], dtype=np.int32),
|
||||
expected=np.array(
|
||||
[[1, 2],
|
||||
[1, 2],
|
||||
[1, 2]],
|
||||
dtype=dtype))
|
||||
|
||||
def testTranspose(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
array_ops.transpose,
|
||||
np.zeros(shape=[1, 0, 4], dtype=dtype),
|
||||
np.array([1, 2, 0], dtype=np.int32),
|
||||
expected=np.zeros(shape=[0, 4, 1], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.transpose,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([0, 1], dtype=np.int32),
|
||||
expected=np.array([[1, 2], [3, 4]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.transpose,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([1, 0], dtype=np.int32),
|
||||
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
78
tensorflow/compiler/tests/build_defs.bzl
Normal file
78
tensorflow/compiler/tests/build_defs.bzl
Normal file
@ -0,0 +1,78 @@
"""Build rules for Tensorflow/XLA testing."""

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")


def all_backends():
  if cuda_is_configured():
    return ["cpu", "gpu"]
  else:
    return ["cpu"]


def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
                   backends=None, **kwargs):
  """Generates py_test targets, one per XLA backend.

  This rule generates py_test() targets named name_backend, for each backend
  in all_backends(). The rule also generates a test suite named `name` that
  tests all backends for the test.

  For example, the following rule generates test cases foo_test_cpu,
  foo_test_gpu, and a test suite named foo_test that tests both.
  tf_xla_py_test(
      name="foo_test",
      srcs="foo_test.py",
      deps=[...],
  )

  Args:
    name: Name of the target.
    srcs: Sources for the target.
    deps: Dependencies of the target.
    tags: Tags to apply to the generated targets.
    data: Data dependencies of the target.
    main: Same as py_test's main attribute.
    backends: A list of backends to test. Supported values include "cpu" and
      "gpu". If not specified, defaults to all backends.
    **kwargs: keyword arguments passed onto the generated py_test() rules.
  """
  if backends == None:
    backends = all_backends()

  test_names = []
  for backend in backends:
    test_name = "{}_{}".format(name, backend)
    backend_tags = ["tf_xla_{}".format(backend)]
    backend_args = []
    backend_deps = []
    backend_data = []
    if backend == "cpu":
      backend_args += ["--test_device=XLA_CPU",
                       "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
    elif backend == "gpu":
      backend_args += ["--test_device=XLA_GPU",
                       "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
      backend_tags += ["requires-gpu-sm35"]
    else:
      fail("Unknown backend {}".format(backend))

    native.py_test(
        name=test_name,
        srcs=srcs,
        srcs_version="PY2AND3",
        args=backend_args,
        main="{}.py".format(name) if main == None else main,
        data=data + backend_data,
        deps=deps + backend_deps,
        tags=tags + backend_tags,
        **kwargs
    )
    test_names.append(test_name)
  native.test_suite(name=name, tests=test_names)


def generate_backend_suites(backends=[]):
  """Generates per-backend test_suites that run all tests for a backend."""
  if not backends:
    backends = all_backends()
  for backend in backends:
    native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
102
tensorflow/compiler/tests/clustering_test.py
Normal file
102
tensorflow/compiler/tests/clustering_test.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the behavior of the auto-compilation pass."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
|
||||
class ClusteringTest(XLATestCase):
|
||||
|
||||
def testAdd(self):
|
||||
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
|
||||
val2 = np.array([5, 6, 7, 8], dtype=np.float32)
|
||||
expected = val1 + val2
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
input1 = constant_op.constant(val1, name="const1")
|
||||
input2 = constant_op.constant(val2, name="const2")
|
||||
output = math_ops.add(input1, input2)
|
||||
result = output.eval()
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testAddFromCpuMultiple(self):
|
||||
val1 = np.array([4, 3, 2, 1]).astype(np.float32)
|
||||
val2 = np.array([5, 6, 7, 8]).astype(np.float32)
|
||||
expected = val1 + val2
|
||||
with self.test_session():
|
||||
with ops.device(CPU_DEVICE):
|
||||
input1 = constant_op.constant(val1, name="const1")
|
||||
input2 = constant_op.constant(val2, name="const2")
|
||||
with self.test_scope():
|
||||
output = math_ops.add(input1, input2)
|
||||
for _ in range(10):
|
||||
result = output.eval()
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testDeadlock(self):
|
||||
# Builds a graph of the form:
|
||||
# x -> y
|
||||
# | \
|
||||
# z -> w
|
||||
# where x and z are placed on the CPU and y and w are placed on the XLA
|
||||
# device. If y and w are clustered for compilation, then the graph will
|
||||
# deadlock since the clustered graph will contain a self-loop.
|
||||
with self.test_session() as sess:
|
||||
with ops.device(CPU_DEVICE):
|
||||
x = array_ops.placeholder(dtypes.float32, [2])
|
||||
with self.test_scope():
|
||||
y = x * 2
|
||||
with ops.device(CPU_DEVICE):
|
||||
z = y * y
|
||||
with self.test_scope():
|
||||
w = y + z
|
||||
result = sess.run(w, {x: [1.5, 0.5]})
|
||||
self.assertAllClose(result, [12., 2.], rtol=1e-3)
|
||||
|
||||
def testHostMemory(self):
|
||||
with self.test_session() as sess:
|
||||
x = array_ops.placeholder(dtypes.int32)
|
||||
with self.test_scope():
|
||||
y = x + 1
|
||||
with ops.device(CPU_DEVICE):
|
||||
# Place a computation on the CPU, so y and w cannot be merged into the
|
||||
# same JIT compilation.
|
||||
z = y * 2
|
||||
with self.test_scope():
|
||||
# Argument 'y' is a non-constant output of a previous cluster. Make sure
|
||||
# it is properly copied to host memory so it can be used as a
|
||||
# compile-time constant input for this cluster.
|
||||
w = array_ops.reshape(z, y)
|
||||
result = sess.run(w, {x: [1, 0]})
|
||||
expected = np.array([[4], [2]], dtype=np.int32)
|
||||
self.assertAllClose(expected, result, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
374
tensorflow/compiler/tests/concat_ops_test.py
Normal file
374
tensorflow/compiler/tests/concat_ops_test.py
Normal file
@ -0,0 +1,374 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functional tests for XLA Concat Op."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ConcatTest(XLATestCase):
|
||||
|
||||
def testHStack(self):
|
||||
with self.test_session():
|
||||
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
|
||||
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
|
||||
with self.test_scope():
|
||||
c = array_ops.concat_v2([p1, p2], 0)
|
||||
params = {
|
||||
p1: np.random.rand(4, 4).astype("f"),
|
||||
p2: np.random.rand(4, 4).astype("f")
|
||||
}
|
||||
result = c.eval(feed_dict=params)
|
||||
|
||||
self.assertEqual(result.shape, c.get_shape())
|
||||
self.assertAllEqual(result[:4, :], params[p1])
|
||||
self.assertAllEqual(result[4:, :], params[p2])
|
||||
|
||||
def testVStack(self):
|
||||
with self.test_session():
|
||||
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
|
||||
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
|
||||
with self.test_scope():
|
||||
c = array_ops.concat_v2([p1, p2], 1)
|
||||
params = {
|
||||
p1: np.random.rand(4, 4).astype("f"),
|
||||
p2: np.random.rand(4, 4).astype("f")
|
||||
}
|
||||
result = c.eval(feed_dict=params)
|
||||
|
||||
self.assertEqual(result.shape, c.get_shape())
|
||||
self.assertAllEqual(result[:, :4], params[p1])
|
||||
self.assertAllEqual(result[:, 4:], params[p2])
|
||||
|
||||
def testInt32(self):
|
||||
with self.test_session():
|
||||
p1 = np.random.rand(2, 3).astype("i")
|
||||
p2 = np.random.rand(2, 3).astype("i")
|
||||
x1 = constant_op.constant(p1)
|
||||
x2 = constant_op.constant(p2)
|
||||
with self.test_scope():
|
||||
c = array_ops.concat_v2([x1, x2], 0)
|
||||
result = c.eval()
|
||||
self.assertAllEqual(result[:2, :], p1)
|
||||
self.assertAllEqual(result[2:, :], p2)
|
||||
|
||||
def _testRandom(self, dtype):
|
||||
# Random dims of rank 5
|
||||
shape = np.random.randint(1, 5, size=5)
|
||||
# Random number of tensors, but always > 1.
|
||||
num_tensors = np.random.randint(2, 10)
|
||||
# Random dim to concat on
|
||||
concat_dim = np.random.randint(5)
|
||||
params = {}
|
||||
if dtype == dtypes.bfloat16:
|
||||
dtype_feed = dtypes.float32
|
||||
else:
|
||||
dtype_feed = dtype
|
||||
with self.test_session():
|
||||
p = []
|
||||
for i in np.arange(num_tensors):
|
||||
input_shape = shape
|
||||
input_shape[concat_dim] = np.random.randint(1, 5)
|
||||
placeholder = array_ops.placeholder(dtype_feed, shape=input_shape)
|
||||
p.append(placeholder)
|
||||
|
||||
t = dtype_feed.as_numpy_dtype
|
||||
params[placeholder] = np.random.rand(*input_shape).astype(t)
|
||||
|
||||
if dtype != dtype_feed:
|
||||
concat_inputs = [math_ops.cast(p_i, dtype) for p_i in p]
|
||||
else:
|
||||
concat_inputs = p
|
||||
with self.test_scope():
|
||||
c = array_ops.concat_v2(concat_inputs, concat_dim)
|
||||
if dtype != dtype_feed:
|
||||
c = math_ops.cast(c, dtype_feed)
|
||||
result = c.eval(feed_dict=params)
|
||||
|
||||
self.assertEqual(result.shape, c.get_shape())
|
||||
cur_offset = 0
|
||||
|
||||
for i in np.arange(num_tensors):
|
||||
# The index into the result is the ':' along all dimensions
|
||||
# except the concat_dim. slice(0, size) is used for ':', and
|
||||
# a list of slices is used to index into result.
|
||||
ind = [slice(0, params[p[i]].shape[j]) for j in np.arange(5)]
|
||||
ind[concat_dim] = slice(cur_offset,
|
||||
cur_offset + params[p[i]].shape[concat_dim])
|
||||
cur_offset += params[p[i]].shape[concat_dim]
|
||||
if dtype == dtype_feed:
|
||||
self.assertAllEqual(result[ind], params[p[i]])
|
||||
else:
|
||||
self.assertAllClose(result[ind], params[p[i]], 0.01)
|
||||
|
||||
def testRandom(self):
|
||||
self._testRandom(dtypes.float32)
|
||||
self._testRandom(dtypes.int32)
|
||||
|
||||
def _testGradientsSimple(self):
|
||||
with self.test_session():
|
||||
inp = []
|
||||
inp_tensors = []
|
||||
with self.test_scope():
|
||||
for x in [1, 2, 6]:
|
||||
shape = [10, x, 2]
|
||||
t = np.random.rand(*shape).astype("f")
|
||||
inp.append(t)
|
||||
inp_tensors.append(
|
||||
constant_op.constant(
|
||||
[float(y) for y in t.flatten()],
|
||||
shape=shape,
|
||||
dtype=dtypes.float32))
|
||||
c = array_ops.concat_v2(inp_tensors, 1)
|
||||
output_shape = [10, 9, 2]
|
||||
grad_inp = np.random.rand(*output_shape).astype("f")
|
||||
grad_tensor = constant_op.constant(
|
||||
[float(x) for x in grad_inp.flatten()], shape=output_shape)
|
||||
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
|
||||
concated_grad = array_ops.concat_v2(grad, 1)
|
||||
result = concated_grad.eval()
|
||||
self.assertAllEqual(result, grad_inp)
|
||||
|
||||
def testGradientsSimpleAll(self):
|
||||
self._testGradientsSimple()
|
||||
|
||||
def _testGradientsFirstDim(self):
|
||||
with self.test_session():
|
||||
inp = []
|
||||
inp_tensors = []
|
||||
with self.test_scope():
|
||||
for x in [1, 2, 6]:
|
||||
shape = [x, 10, 2]
|
||||
t = np.random.rand(*shape).astype("f")
|
||||
inp.append(t)
|
||||
inp_tensors.append(
|
||||
constant_op.constant(
|
||||
[float(y) for y in t.flatten()],
|
||||
shape=shape,
|
||||
dtype=dtypes.float32))
|
||||
c = array_ops.concat_v2(inp_tensors, 0)
|
||||
output_shape = [9, 10, 2]
|
||||
grad_inp = np.random.rand(*output_shape).astype("f")
|
||||
grad_tensor = constant_op.constant(
|
||||
[float(x) for x in grad_inp.flatten()], shape=output_shape)
|
||||
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
|
||||
concated_grad = array_ops.concat_v2(grad, 0)
|
||||
result = concated_grad.eval()
|
||||
|
||||
self.assertAllEqual(result, grad_inp)
|
||||
|
||||
def testGradientsFirstDimAll(self):
|
||||
self._testGradientsFirstDim()
|
||||
|
||||
def _testGradientsLastDim(self):
|
||||
with self.test_session():
|
||||
inp = []
|
||||
inp_tensors = []
|
||||
with self.test_scope():
|
||||
for x in [1, 2, 6]:
|
||||
shape = [10, 2, x]
|
||||
t = np.random.rand(*shape).astype("f")
|
||||
inp.append(t)
|
||||
inp_tensors.append(
|
||||
constant_op.constant(
|
||||
[float(y) for y in t.flatten()],
|
||||
shape=shape,
|
||||
dtype=dtypes.float32))
|
||||
c = array_ops.concat_v2(inp_tensors, 2)
|
||||
output_shape = [10, 2, 9]
|
||||
grad_inp = np.random.rand(*output_shape).astype("f")
|
||||
grad_tensor = constant_op.constant(
|
||||
[float(x) for x in grad_inp.flatten()], shape=output_shape)
|
||||
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
|
||||
concated_grad = array_ops.concat_v2(grad, 2)
|
||||
result = concated_grad.eval()
|
||||
|
||||
self.assertAllEqual(result, grad_inp)
|
||||
|
||||
def testGradientsLastDimAll(self):
|
||||
self._testGradientsLastDim()
|
||||
|
||||
def _RunAndVerifyGradientsRandom(self):
|
||||
# Random dims of rank 5
|
||||
input_shape = np.random.randint(1, 5, size=5)
|
||||
# Random number of tensors
|
||||
num_tensors = np.random.randint(1, 10)
|
||||
# Random dim to concat on
|
||||
concat_dim = np.random.randint(5)
|
||||
concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
|
||||
with self.test_session():
|
||||
inp = []
|
||||
inp_tensors = []
|
||||
with self.test_scope():
|
||||
for x in concat_dim_sizes:
|
||||
shape = input_shape
|
||||
shape[concat_dim] = x
|
||||
t = np.random.rand(*shape).astype("f")
|
||||
inp.append(t)
|
||||
inp_tensors.append(
|
||||
constant_op.constant(
|
||||
[float(y) for y in t.flatten()],
|
||||
shape=shape,
|
||||
dtype=dtypes.float32))
|
||||
c = array_ops.concat_v2(inp_tensors, concat_dim)
|
||||
output_shape = input_shape
|
||||
output_shape[concat_dim] = concat_dim_sizes.sum()
|
||||
grad_inp = np.random.rand(*output_shape).astype("f")
|
||||
grad_tensor = constant_op.constant(
|
||||
[float(x) for x in grad_inp.flatten()], shape=output_shape)
|
||||
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
|
||||
concated_grad = array_ops.concat_v2(grad, concat_dim)
|
||||
result = concated_grad.eval()
|
||||
|
||||
self.assertAllEqual(result, grad_inp)
|
||||
|
||||
def testGradientsRandom(self):
|
||||
for _ in range(5):
|
||||
self._RunAndVerifyGradientsRandom()
|
||||
|
||||
# Re-enable once zero-element Retvals are handled correctly.
|
||||
def DISABLED_testZeroSize(self):
|
||||
# Verify that concat doesn't crash and burn for zero size inputs
|
||||
np.random.seed(7)
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
for shape0 in (), (2,):
|
||||
axis = len(shape0)
|
||||
for shape1 in (), (3,):
|
||||
for n0 in 0, 1, 2:
|
||||
for n1 in 0, 1, 2:
|
||||
x0 = np.random.randn(*(shape0 + (n0,) + shape1))
|
||||
x1 = np.random.randn(*(shape0 + (n1,) + shape1))
|
||||
correct = np.concatenate([x0, x1], axis=axis)
|
||||
# TODO(irving): Make tf.concat handle map, then drop list().
|
||||
xs = list(map(constant_op.constant, [x0, x1]))
|
||||
c = array_ops.concat_v2(xs, axis)
|
||||
self.assertAllEqual(c.eval(), correct)
|
||||
# Check gradients
|
||||
dc = np.random.randn(*c.get_shape().as_list())
|
||||
dxs = sess.run(gradients_impl.gradients(c, xs, dc))
|
||||
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
|
||||
|
||||
def testTensorConcatDim0Grad(self):
|
||||
x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
|
||||
output_shape = [44, 7, 3]
|
||||
x_vals = [
|
||||
np.random.random_sample(x_shape).astype(np.float32)
|
||||
for x_shape in x_shapes
|
||||
]
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
xs = [constant_op.constant(x_val) for x_val in x_vals]
|
||||
output = array_ops.concat_v2(xs, 0)
|
||||
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
|
||||
output_shape)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testTensorConcatDim1Grad(self):
|
||||
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
|
||||
output_shape = [20, 11, 3]
|
||||
x_vals = [
|
||||
np.random.random_sample(x_shape).astype(np.float32)
|
||||
for x_shape in x_shapes
|
||||
]
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
xs = [constant_op.constant(x_val) for x_val in x_vals]
|
||||
output = array_ops.concat_v2(xs, 1)
|
||||
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
|
||||
output_shape)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testConcatTuple(self):
|
||||
c1 = np.random.rand(4, 4).astype(np.float32)
|
||||
c2 = np.random.rand(4, 4).astype(np.float32)
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
concat_list_t = array_ops.concat_v2([c1, c2], 0)
|
||||
concat_tuple_t = array_ops.concat_v2((c1, c2), 0)
|
||||
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
|
||||
|
||||
def testConcatNoScalars(self):
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
scalar = constant_op.constant(7)
|
||||
dim = array_ops.placeholder(dtypes.int32)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Can't concatenate scalars \(use tf\.pack instead\)"):
|
||||
array_ops.concat_v2([scalar, scalar, scalar], dim)
|
||||
|
||||
|
||||
class ConcatOffsetTest(XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
cdim = constant_op.constant(1, dtypes.int32)
|
||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
off = gen_array_ops._concat_offset(cdim, [s0, s1, s2])
|
||||
ans = sess.run(off)
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
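# ConcatOffset returns, for each input shape, the offset at which that input
# starts along the concat dimension: 0, 3 and 3 + 7 = 10 along dimension 1.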
|
||||
|
||||
|
||||
class PackTest(XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
||||
|
||||
def testScalars(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant(2, dtypes.int32)
|
||||
s1 = constant_op.constant(3, dtypes.int32)
|
||||
s2 = constant_op.constant(5, dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [2, 3, 5])
|
||||
|
||||
def testEmpty(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant([[]], dtypes.int32)
|
||||
s1 = constant_op.constant([[]], dtypes.int32)
|
||||
s2 = constant_op.constant([[]], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
526
tensorflow/compiler/tests/conv2d_test.py
Normal file
526
tensorflow/compiler/tests/conv2d_test.py
Normal file
@ -0,0 +1,526 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Conv2D via the XLA JIT.
|
||||
|
||||
The canned results in these tests are created by running each test using the
|
||||
Tensorflow CPU device and saving the output.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import nn_impl
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class Conv2DTest(XLATestCase):
|
||||
|
||||
def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected):
|
||||
"""Tests that tf.nn.conv2d produces the expected value.
|
||||
|
||||
Args:
|
||||
input_sizes: Input tensor dimensions in
|
||||
[batch, input_rows, input_cols, input_depth].
|
||||
filter_sizes: Filter tensor dimensions in
|
||||
[kernel_rows, kernel_cols, input_depth, output_depth].
|
||||
stride: Stride.
|
||||
padding: Padding type.
|
||||
expected: Expected output.
|
||||
"""
|
||||
total_size_1 = np.prod(input_sizes)
|
||||
total_size_2 = np.prod(filter_sizes)
|
||||
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
|
||||
x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes)
|
||||
strides = [1, stride, stride, 1]
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
|
||||
t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
|
||||
out = nn_ops.conv2d(
|
||||
t1, t2, strides=strides, padding=padding, data_format="NHWC")
|
||||
value = sess.run(out, {t1: x1, t2: x2})
|
||||
self.assertArrayNear(expected, np.ravel(value), 1e-3)
|
||||
|
||||
def testConv2D1x1Filter(self):
|
||||
expected_output = [
|
||||
30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
|
||||
204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[1, 1, 3, 3],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
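# With inputs filled by np.arange starting at 1, the first expected value is
# the depth values [1, 2, 3] of the first pixel dotted with the 1x1 filter
# weights for output channel 0 ([1, 4, 7]): 1*1 + 2*4 + 3*7 = 30.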
|
||||
|
||||
def testConv2D2x2Filter(self):
|
||||
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2Filter(self):
|
||||
expected_output = [
|
||||
231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0,
|
||||
936.0, 1029.0
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[1, 2, 3, 3],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2(self):
|
||||
expected_output = [2271.0, 2367.0, 2463.0]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
stride=2,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2Same(self):
|
||||
expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
stride=2,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
|
||||
class Conv2DBackpropInputTest(XLATestCase):
|
||||
|
||||
def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
|
||||
padding, expected):
|
||||
"""Tests that gen_nn_ops.conv2d_backprop_input produces the expected output.
|
||||
|
||||
Args:
|
||||
input_sizes: Input tensor dimensions in
|
||||
[batch, input_rows, input_cols, input_depth].
|
||||
filter_sizes: Filter tensor dimensions in
|
||||
[kernel_rows, kernel_cols, input_depth, output_depth].
|
||||
out_backprop_sizes: Output gradients tensor dimensions.
|
||||
stride: Stride.
|
||||
padding: Padding type.
|
||||
expected: Expected output.
|
||||
"""
|
||||
total_size_1 = np.prod(filter_sizes)
|
||||
total_size_2 = np.prod(out_backprop_sizes)
|
||||
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes)
|
||||
x2 = np.arange(
|
||||
1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
|
||||
strides = [1, stride, stride, 1]
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
|
||||
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
|
||||
out = gen_nn_ops.conv2d_backprop_input(
|
||||
input_sizes=input_sizes,
|
||||
filter=t1,
|
||||
out_backprop=t2,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format="NHWC")
|
||||
value = sess.run(out, {t1: x1, t2: x2})
|
||||
self.assertArrayNear(expected, np.ravel(value), 1e-3)
|
||||
|
||||
def testConv2D1x1Filter(self):
|
||||
expected_output = [
|
||||
5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127,
|
||||
41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71,
|
||||
165, 259, 77, 179, 281, 83, 193, 303, 89, 207, 325, 95, 221, 347.
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 4, 4, 3],
|
||||
filter_sizes=[1, 1, 3, 2],
|
||||
out_backprop_sizes=[1, 4, 4, 2],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width5(self):
|
||||
expected_output = [1, 2, 0, 2, 4]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 5, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width6(self):
|
||||
expected_output = [1, 2, 0, 2, 4, 0]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 6, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width7(self):
|
||||
expected_output = [1, 2, 0, 2, 4, 0, 0]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 7, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterC1Same(self):
|
||||
expected_output = [1, 4, 7, 7, 23, 33]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 2, 3, 1],
|
||||
stride=1,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2Filter(self):
|
||||
expected_output = [
|
||||
14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604,
|
||||
437, 482, 527
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
out_backprop_sizes=[1, 1, 2, 3],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterSame(self):
|
||||
expected_output = [
|
||||
14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217,
|
||||
1505, 1487, 1883, 2279
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
out_backprop_sizes=[1, 2, 3, 3],
|
||||
stride=1,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2Filter(self):
|
||||
expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 3, 2, 1],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterSame(self):
|
||||
expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 3, 3, 1],
|
||||
stride=1,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2(self):
|
||||
expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 3, 5, 1],
|
||||
filter_sizes=[1, 3, 1, 1],
|
||||
out_backprop_sizes=[1, 2, 2, 1],
|
||||
stride=2,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2Same(self):
|
||||
expected_output = [1, 2, 2, 3, 4, 6]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=2,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
|
||||
class Conv2DBackpropFilterTest(XLATestCase):
|
||||
|
||||
def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
|
||||
padding, expected):
|
||||
"""Tests that gen_nn_ops.conv2d_backprop_filter produces the right output.
|
||||
|
||||
Args:
|
||||
input_sizes: Input tensor dimensions in
|
||||
[batch, input_rows, input_cols, input_depth].
|
||||
filter_sizes: Filter tensor dimensions in
|
||||
[kernel_rows, kernel_cols, input_depth, output_depth].
|
||||
out_backprop_sizes: Output gradients tensor dimensions.
|
||||
stride: Stride.
|
||||
padding: Padding type.
|
||||
expected: Expected output.
|
||||
"""
|
||||
|
||||
total_size_1 = np.prod(input_sizes)
|
||||
total_size_2 = np.prod(out_backprop_sizes)
|
||||
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
|
||||
x2 = np.arange(
|
||||
1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
|
||||
strides = [1, stride, stride, 1]
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
|
||||
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
|
||||
tensor = gen_nn_ops.conv2d_backprop_filter(
|
||||
input=t1,
|
||||
filter_sizes=filter_sizes,
|
||||
out_backprop=t2,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format="NHWC")
|
||||
|
||||
value = sess.run(tensor, {t1: x1, t2: x2})
|
||||
self.assertArrayNear(expected, np.ravel(value), 1e-5)
|
||||
|
||||
def testConv2D1x1Filter(self):
|
||||
expected_output = [8056, 8432, 8312, 8704, 8568, 8976]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 4, 4, 3],
|
||||
filter_sizes=[1, 1, 3, 2],
|
||||
out_backprop_sizes=[1, 4, 4, 2],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2Filter(self):
|
||||
expected_output = [120, 141]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 3, 2, 1],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterDepth1(self):
|
||||
expected_output = [5, 8, 14, 17]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2Filter(self):
|
||||
expected_output = [
|
||||
17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72,
|
||||
62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87,
|
||||
120, 153
|
||||
]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
filter_sizes=[2, 2, 3, 3],
|
||||
out_backprop_sizes=[1, 1, 2, 3],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width5(self):
|
||||
expected_output = [9, 12]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 5, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width6(self):
|
||||
expected_output = [9, 12]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 6, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x2FilterStride3Width7(self):
|
||||
expected_output = [9, 12]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 7, 1],
|
||||
filter_sizes=[1, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=3,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x3Filter(self):
|
||||
expected_output = [5, 8, 11]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 4, 1],
|
||||
filter_sizes=[1, 3, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=1,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x3FilterSame(self):
|
||||
expected_output = [20, 30, 20]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 4, 1],
|
||||
filter_sizes=[1, 3, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 4, 1],
|
||||
stride=1,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D1x3FilterSameOutbackprop2(self):
|
||||
expected_output = [7, 10, 3]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 1, 4, 1],
|
||||
filter_sizes=[1, 3, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=2,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterC1Same(self):
|
||||
expected_output = [91, 58, 32, 17]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 2, 3, 1],
|
||||
stride=1,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2(self):
|
||||
expected_output = [92, 102, 112]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 3, 5, 1],
|
||||
filter_sizes=[1, 3, 1, 1],
|
||||
out_backprop_sizes=[1, 2, 2, 1],
|
||||
stride=2,
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testConv2D2x2FilterStride2Same(self):
|
||||
expected_output = [7, 2, 16, 5]
|
||||
self._VerifyValues(
|
||||
input_sizes=[1, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
out_backprop_sizes=[1, 1, 2, 1],
|
||||
stride=2,
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
|
||||
class DepthwiseConv2DTest(XLATestCase):
|
||||
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
def ConfigsToTest(self):
|
||||
input_sizes = [[4, 35, 35, 2], [4, 147, 147, 2], [3, 299, 299, 3],
|
||||
[5, 183, 183, 1]]
|
||||
filter_sizes = [[5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
|
||||
strides = [1, 3, 2, 2]
|
||||
# pylint: disable=invalid-name
|
||||
VALID = "VALID"
|
||||
SAME = "SAME"
|
||||
# pylint: enable=invalid-name
|
||||
paddings = [SAME, VALID, SAME, SAME, SAME]
|
||||
for i, f, s, p in zip(input_sizes, filter_sizes, strides, paddings):
|
||||
yield i, f, s, p
|
||||
|
||||
def _VerifyValues(self, input_size, filter_size, stride, padding):
|
||||
imag = np.random.rand(*input_size).astype(np.float32)
|
||||
filt = np.random.rand(*filter_size).astype(np.float32)
|
||||
strides = [1, stride, stride, 1]
|
||||
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
|
||||
filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
|
||||
feed_dict = {imag_ph: imag, filt_ph: filt}
|
||||
xla_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
|
||||
padding).eval(feed_dict=feed_dict)
|
||||
|
||||
with self.test_session():
|
||||
with ops.device(self.CPU_DEVICE):
|
||||
imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
|
||||
filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
|
||||
feed_dict = {imag_ph: imag, filt_ph: filt}
|
||||
cpu_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
|
||||
padding).eval(feed_dict=feed_dict)
|
||||
|
||||
self.assertAllClose(xla_out, cpu_out)
|
||||
|
||||
# This is disabled because we need a mechanism to set command-line flags,
|
||||
# i.e. an implementation of SetCommandLineOption() below.
|
||||
#
|
||||
# def _VerifyDummy(self, input_size, filter_size, stride, padding):
|
||||
# imag = np.random.rand(*input_size).astype(np.float32)
|
||||
# filt = np.random.rand(*filter_size).astype(np.float32)
|
||||
# strides = [1, stride, stride, 1]
|
||||
#
|
||||
# with self.test_session():
|
||||
# with self.test_scope():
|
||||
# imag_ph = tf.placeholder(tf.float32, shape=input_size)
|
||||
# filt_ph = tf.placeholder(tf.float32, shape=filter_size)
|
||||
# feed_dict = {imag_ph: imag, filt_ph: filt}
|
||||
# SetCommandLineOption(
|
||||
# "tf_tla_depthwise_conv2d_custom_func",
|
||||
# "DummyDepthwiseConv2dKernel")
|
||||
# xla_out = tf.nn.depthwise_conv2d(
|
||||
# imag_ph, filt_ph, strides, padding).eval(feed_dict=feed_dict)
|
||||
# SetCommandLineOption(
|
||||
# "tf_tla_depthwise_conv2d_custom_func", "")
|
||||
#
|
||||
# expected = np.array(range(np.ravel(xla_out).shape[0]), dtype=np.float32)
|
||||
# self.assertAllClose(np.ravel(xla_out), expected)
|
||||
|
||||
def testBasic(self):
|
||||
for i, f, s, p in self.ConfigsToTest():
|
||||
self._VerifyValues(i, f, s, p)
|
||||
|
||||
# Test disabled until _VerifyDummy(), above, can be implemented.
|
||||
# def testCustomFunc(self):
|
||||
# if self.has_custom_call:
|
||||
# for i, f, s, p in self.ConfigsToTest():
|
||||
# self._VerifyDummy(i, f, s, p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
30
tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::int64;
|
||||
|
||||
// A dummy implementation that fills the output with 0, 1, 2,...
|
||||
// to test the custom call implementation of DepthwiseConv2dNative op.
|
||||
// TODO(keveman): Test this after adding a real implementation for the kernel.
|
||||
extern "C" void DummyDepthwiseConv2dKernel(float* output, void** inputs) {
|
||||
const int64* output_size = reinterpret_cast<const int64*>(inputs[4]);
|
||||
const int64 total_size =
|
||||
output_size[0] * output_size[1] * output_size[2] * output_size[3];
|
||||
for (int64 i = 0; i < total_size; ++i) {
|
||||
*(output + i) = i;
|
||||
}
|
||||
}
|
86
tensorflow/compiler/tests/dynamic_stitch_test.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.dynamic_stitch."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class DynamicStitchTest(XLATestCase):
|
||||
|
||||
def _AssertDynamicStitchResultIs(self, indices, data, expected):
|
||||
with self.test_session() as session:
|
||||
index_placeholders = [
|
||||
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
|
||||
]
|
||||
data_placeholders = [
|
||||
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
|
||||
]
|
||||
with self.test_scope():
|
||||
output = data_flow_ops.dynamic_stitch(index_placeholders,
|
||||
data_placeholders)
|
||||
|
||||
feed_dict = {}
|
||||
for placeholder, value in zip(index_placeholders, indices):
|
||||
feed_dict[placeholder] = value
|
||||
for placeholder, value in zip(data_placeholders, data):
|
||||
feed_dict[placeholder] = value
|
||||
result = session.run(output, feed_dict=feed_dict)
|
||||
self.assertAllClose(expected, result, rtol=1e-3)
|
||||
|
||||
def testSimpleEmpty(self):
|
||||
idx1 = np.array([0, 2], dtype=np.int32)
|
||||
idx2 = np.array([[1], [3]], dtype=np.int32)
|
||||
val1 = np.array([[], []], dtype=np.int32)
|
||||
val2 = np.array([[[]], [[]]], dtype=np.int32)
|
||||
self._AssertDynamicStitchResultIs(
|
||||
[idx1, idx2], [val1, val2],
|
||||
expected=np.array([[], [], [], []], np.int32))
|
||||
|
||||
def testSimple1D(self):
|
||||
val1 = np.array([0, 4, 7], dtype=np.int32)
|
||||
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
|
||||
val3 = np.array([0, 40, 70], dtype=np.float32)
|
||||
val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32)
|
||||
expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32)
|
||||
self._AssertDynamicStitchResultIs(
|
||||
[val1, val2], [val3, val4], expected=expected)
|
||||
|
||||
def testSimple2D(self):
|
||||
val1 = np.array([0, 4, 7], dtype=np.int32)
|
||||
val2 = np.array([1, 6], dtype=np.int32)
|
||||
val3 = np.array([2, 3, 5], dtype=np.int32)
|
||||
val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32)
|
||||
val5 = np.array([[10, 11], [60, 61]], dtype=np.float32)
|
||||
val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32)
|
||||
expected = np.array(
|
||||
[[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61],
|
||||
[70, 71]],
|
||||
dtype=np.float32)
|
||||
self._AssertDynamicStitchResultIs(
|
||||
[val1, val2, val3], [val4, val5, val6], expected=expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
130
tensorflow/compiler/tests/function_test.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for Tensorflow functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class FunctionTest(XLATestCase):
|
||||
|
||||
def testFunction(self):
|
||||
"""Executes a simple TensorFlow function."""
|
||||
|
||||
def APlus2B(a, b):
|
||||
return a + b * 2
|
||||
|
||||
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
|
||||
expected = APlus2B(aval, bval)
|
||||
|
||||
with self.test_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
return APlus2B(a, b)
|
||||
|
||||
a = constant_op.constant(aval, name="a")
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testNestedFunctions(self):
|
||||
"""Executes two nested TensorFlow functions."""
|
||||
|
||||
def TimesTwo(x):
|
||||
return x * 2
|
||||
|
||||
def APlus2B(a, b):
|
||||
return a + TimesTwo(b)
|
||||
|
||||
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
expected = APlus2B(aval, bval)
|
||||
|
||||
with self.test_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
return APlus2B(a, b)
|
||||
|
||||
a = constant_op.constant(aval, name="a")
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_g = Foo(a, b)
|
||||
result = sess.run(call_g)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFunctionMultipleRetvals(self):
|
||||
"""Executes a function with multiple return values."""
|
||||
|
||||
# This function will run on the XLA device
|
||||
def Func(a, b):
|
||||
return a + b, a - b
|
||||
|
||||
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
|
||||
expected = Func(aval, bval)
|
||||
|
||||
with self.test_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
return Func(a, b)
|
||||
|
||||
a = constant_op.constant(aval, name="a")
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFunctionsNoInline(self):
|
||||
|
||||
@function.Defun(dtypes.float32, noinline=True)
|
||||
def TimesTwo(x):
|
||||
return x * 2
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def APlus2B(a, b):
|
||||
return a + TimesTwo(b)
|
||||
|
||||
aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
expected = aval + bval * 2
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
a = array_ops.placeholder(dtypes.float32, name="a")
|
||||
b = array_ops.placeholder(dtypes.float32, name="b")
|
||||
call = APlus2B(a, b)
|
||||
result = sess.run(call, {a: aval, b: bval})
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
459
tensorflow/compiler/tests/jit_test.py
Normal file
@ -0,0 +1,459 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for JIT compilation on the CPU and GPU devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.compiler import jit
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
jit_scope = jit.experimental_jit_scope
|
||||
|
||||
|
||||
def CompiledKernel(fn, *inputs, **kwargs):
|
||||
"""Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
|
||||
name = kwargs.pop("name", None)
|
||||
noinline = kwargs.pop("noinline", None)
|
||||
|
||||
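# compiled=True marks the function for XLA JIT compilation; the tests then
# look for the resulting _XlaLaunch kernel in the run metadata.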
@function.Defun(func_name=name, noinline=noinline, compiled=True)
|
||||
def Compiled(*args):
|
||||
return fn(*args)
|
||||
|
||||
return Compiled(*inputs)
|
||||
|
||||
|
||||
def RunMetadataLabels(run_metadata):
|
||||
"""Returns all labels in run_metadata."""
|
||||
labels = []
|
||||
for dev_stats in run_metadata.step_stats.dev_stats:
|
||||
for node_stats in dev_stats.node_stats:
|
||||
labels.append(node_stats.timeline_label)
|
||||
return labels
|
||||
|
||||
|
||||
def InLabels(labels, substr):
|
||||
"""Returns true iff one of the labels contains substr."""
|
||||
return any([substr in x for x in labels])
|
||||
|
||||
|
||||
def MetadataHasXlaLaunch(run_metadata):
|
||||
"""Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
|
||||
|
||||
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
|
||||
return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
|
||||
|
||||
|
||||
class JitLaunchTest(test.TestCase):
|
||||
|
||||
# Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
|
||||
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
|
||||
# the same number of tensor arguments as there are entries in 'args', and must return
|
||||
# a tuple of output tensors.
|
||||
# If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
|
||||
# actually ran. However, it is sometimes possible for _XlaLaunch ops to be
|
||||
# constant-folded away, so the check is optional.
|
||||
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
|
||||
with session_lib.Session() as sess:
|
||||
placeholders = []
|
||||
feeds = {}
|
||||
for arg in args:
|
||||
placeholder = array_ops.placeholder(
|
||||
dtypes.as_dtype(arg.dtype), list(arg.shape))
|
||||
placeholders.append(placeholder)
|
||||
feeds[placeholder] = arg
|
||||
|
||||
compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
|
||||
direct_op = fn(*placeholders)
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
compiled = sess.run(compiled_op,
|
||||
feeds,
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
print("Compiled Result {}".format(compiled))
|
||||
|
||||
if require_kernel_launch:
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
|
||||
direct = sess.run(direct_op, feeds)
|
||||
print("Direct Result {}".format(direct))
|
||||
|
||||
if (isinstance(compiled, (tuple, list)) and
|
||||
(isinstance(direct, (tuple, list)))):
|
||||
for (x, y) in zip(compiled, direct):
|
||||
self.assertAllClose(x, y, rtol=1e-1)
|
||||
else:
|
||||
self.assertAllClose(compiled, direct)
|
||||
|
||||
def testNoOutputs(self):
|
||||
with session_lib.Session() as sess:
|
||||
# Build a function with a single Const node, whose output is ignored.
|
||||
fdef = function_pb2.FunctionDef()
|
||||
fdef.signature.name = "KernelWithNoOutputs"
|
||||
node = node_def_pb2.NodeDef()
|
||||
node.op = "Const"
|
||||
node.name = "ignored"
|
||||
node.attr["dtype"].type = dtypes.int32.as_datatype_enum
|
||||
tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[])
|
||||
node.attr["value"].tensor.CopyFrom(tensor)
|
||||
fdef.node_def.extend([node])
|
||||
|
||||
# Check that calling the result as a compiled kernel doesn't crash.
|
||||
@function.Defun(compiled=True)
|
||||
def KernelWithNoOutputs():
|
||||
return constant_op.constant(100)
|
||||
|
||||
# Hack to override the definition. By accessing .definition, we
|
||||
# force the _DefinedFunction to be initialized internally. Then we
|
||||
# replace its internal FunctionDef proto. We use this hack here
|
||||
# because one typically can't construct the KernelWithNoOutputs
|
||||
# function via the Defun decorator directly.
|
||||
_ = KernelWithNoOutputs.definition
|
||||
foo = KernelWithNoOutputs
|
||||
foo._definition = fdef
|
||||
call = KernelWithNoOutputs()
|
||||
sess.run(call, {})
|
||||
|
||||
def testAliasing(self):
|
||||
"""Regression test for compiled functions that return an aliased buffer.
|
||||
|
||||
XLA returns aliased buffers if outputs are identical. Tests that
|
||||
we handle that case.
|
||||
"""
|
||||
|
||||
def AddOnceReturnTwice(x):
|
||||
y = math_ops.add(x, x)
|
||||
return y, y
|
||||
|
||||
# Exercises compiling a function (say, Foo) which calls another
|
||||
# function (say, Bar) which is not inlined. When the compiler compiles
|
||||
# Foo, it needs to symbolically execute Bar correctly regardless of whether
|
||||
# Bar is inlined or not.
|
||||
#
|
||||
# Tests compiled=True and noinline=True.
|
||||
self._compare(
|
||||
AddOnceReturnTwice, [np.array(
|
||||
[[[0.5, -1.0]]], dtype=np.float32)],
|
||||
noinline=True)
|
||||
# Tests compiled=True and noinline=False.
|
||||
self._compare(
|
||||
AddOnceReturnTwice, [np.array(
|
||||
[[[0.5, -1.0]]], dtype=np.float32)],
|
||||
noinline=False)
|
||||
|
||||
def testOneConstOutput(self):
|
||||
"""Test consisting of a single constant return value."""
|
||||
|
||||
def OneConstOutput():
|
||||
return constant_op.constant([-3, 44, 99])
|
||||
|
||||
self._compare(OneConstOutput, [], require_kernel_launch=False)
|
||||
|
||||
def testConstZeroElementOutput(self):
|
||||
"""Test consisting of a constant zero element return value."""
|
||||
|
||||
def ConstZeroElementOutput():
|
||||
return array_ops.fill([7, 0], 3.0)
|
||||
|
||||
self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)
|
||||
|
||||
def testSomeConstOutputs(self):
|
||||
"""Test kernels that return a mixture of const and non-const outputs."""
|
||||
|
||||
def SomeConstOutputs(x):
|
||||
return constant_op.constant(
|
||||
[-2, 7]), array_ops.identity(x), constant_op.constant(3.5)
|
||||
|
||||
self._compare(
|
||||
SomeConstOutputs, [np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=np.float32)])
|
||||
|
||||
def testInt32Input(self):
|
||||
"""Test an int32-typed input.
|
||||
|
||||
On a GPU, int32 tensors will be placed in host memory.
|
||||
"""
|
||||
|
||||
def AddToSelf(x):
|
||||
return math_ops.add(x, x)
|
||||
|
||||
self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
|
||||
|
||||
def testMandatoryConstantInput(self):
|
||||
"""Tests an operator that has a mandatory-constant shape input."""
|
||||
|
||||
def FillWithFloat(x):
|
||||
return array_ops.fill(x, 9.5)
|
||||
|
||||
self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])
|
||||
|
||||
def testMnistForwardFunc(self):
|
||||
"""Compute inference function from MNIST beginners tutorial."""
|
||||
batch_size = 16
|
||||
image_size = 28 * 28
|
||||
num_classes = 10
|
||||
|
||||
# Define a TensorFlow function to compute the forward pass.
|
||||
def MnistForward(w, b, x):
|
||||
return nn_ops.softmax(math_ops.matmul(x, w) + b)
|
||||
|
||||
w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
|
||||
b = np.random.random_sample((num_classes)).astype(np.float32)
|
||||
x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
|
||||
self._compare(MnistForward, [w, b, x])
|
||||
|
||||
def testExplicitMarking(self):
|
||||
"""Test explicit marking of operators to compile."""
|
||||
batch_size = 16
|
||||
image_size = 28 * 28
|
||||
num_classes = 10
|
||||
|
||||
with ops.Graph().as_default():
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
w = array_ops.placeholder(dtypes.float32)
|
||||
b = array_ops.placeholder(dtypes.float32)
|
||||
with jit_scope():
|
||||
y1 = math_ops.matmul(x, w)
|
||||
y2 = math_ops.add(y1, b)
|
||||
with jit_scope():
|
||||
y = math_ops.square(y2)
|
||||
|
||||
dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
|
||||
db = np.random.random_sample((num_classes)).astype(np.float32)
|
||||
dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
|
||||
with session_lib.Session() as sess:
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
output = sess.run(y, {x: dx,
|
||||
w: dw,
|
||||
b: db},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
|
||||
# TODO(phawkins): really we would like to test that there were exactly
|
||||
# two kernel launches. However, we have no reliable way to determine
|
||||
# that.
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
|
||||
expected = np.square(np.dot(dx, dw) + db)
|
||||
self.assertAllClose(expected, output, rtol=1e-1)
|
||||
|
||||
|
||||
class XlaCompilationTest(test.TestCase):
|
||||
"""Tests for auto-compilation on CPU/GPU devices."""
|
||||
|
||||
def testReshape(self):
|
||||
"""Tests an operator with compile-time constant and non-constant inputs."""
|
||||
|
||||
with self.test_session() as sess:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = array_ops.placeholder(dtypes.int32)
|
||||
with jit_scope():
|
||||
# Reshape's first argument is non-constant in the JIT, but its second
|
||||
# (shape) argument will be treated as a compile-time constant for
|
||||
# each JIT compilation.
|
||||
# We do not use a tf.constant() argument since we want to ensure the
|
||||
# shape is still a run-time argument to the JIT, and not
|
||||
# statically known as part of the JIT compilation's input graph.
|
||||
z = array_ops.reshape(x, y)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
out = sess.run(z,
|
||||
{x: np.array([1, 2, 3, 4, 5, 6], np.float32),
|
||||
y: [-1, 3]},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
|
||||
|
||||
def testIgnoredArguments(self):
|
||||
"""Tests that JIT computations can ignore formal parameters."""
|
||||
|
||||
with self.test_session() as sess:
|
||||
x = array_ops.placeholder(dtypes.int32)
|
||||
y = array_ops.placeholder(dtypes.int32)
|
||||
with jit_scope():
|
||||
z = math_ops.add(x, x)
|
||||
w = math_ops.add(y, y)
|
||||
# Pulls 'w' into the same compilation via control dependencies.
|
||||
with ops.control_dependencies([w]):
|
||||
n = control_flow_ops.no_op()
|
||||
with ops.control_dependencies([n]):
|
||||
t = math_ops.add(z, z)
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
out = sess.run(t, {x: np.int32(7),
|
||||
y: np.int32(404)},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
self.assertAllClose(28, out)
|
||||
|
||||
def testLoops(self):
|
||||
"""Tests that compilation accepts computations containing loops."""
|
||||
|
||||
with self.test_session() as session:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
with jit_scope():
|
||||
c = lambda i, _: math_ops.less(i, 5)
|
||||
b = lambda i, x: (i + 1, x * 2.0 + 1.0)
|
||||
_, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
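# Starting from x = 2, five iterations of x -> 2 * x + 1 give
# 5, 11, 23, 47, 95, which is the value checked below.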
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
result = session.run(y, {x: np.float32(2)},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
self.assertAllClose(result, np.float32(95), rtol=1e-1)
|
||||
|
||||
def testCond(self):
|
||||
"""Tests that compilation handles switch operators."""
|
||||
|
||||
with self.test_session() as session:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = array_ops.placeholder(dtypes.float32)
|
||||
c = array_ops.placeholder(dtypes.bool)
|
||||
with jit_scope():
|
||||
z = x + 1.0
|
||||
w = control_flow_ops.cond(c, lambda: z, lambda: y)
|
||||
t = math_ops.add(z, w)
|
||||
|
||||
# If JIT compilation chooses to cluster z and t, then execution will
|
||||
# deadlock.
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
result = session.run(t, {x: np.float32(2),
|
||||
y: np.float32(4),
|
||||
c: True},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assert_(MetadataHasXlaLaunch(run_metadata))
|
||||
self.assertAllClose(result, np.float32(6), rtol=1e-1)
|
||||
|
||||
def testNestedFunction(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
|
||||
@function.Defun(compiled=True)
|
||||
def Bar(x, y):
|
||||
return x + 2 * y
|
||||
|
||||
@function.Defun(compiled=True)
|
||||
def Foo(x):
|
||||
return Bar(x * x, x * x * x)
|
||||
|
||||
@function.Defun()
|
||||
def Entry(x):
|
||||
return Foo(x)
|
||||
|
||||
inp = array_ops.placeholder(dtypes.float32)
|
||||
out = Entry(inp)
|
||||
|
||||
with self.test_session(graph=g, use_gpu=True) as sess:
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
val = sess.run(out,
|
||||
feed_dict={inp: [2., 10.]},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assertAllClose(val, [20., 2100.])
|
||||
|
||||
def testLoopDeadlock(self):
|
||||
"""Regression test for bug that caused deadlocks in graphs with loops."""
|
||||
|
||||
with self.test_session() as session:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
with jit_scope():
|
||||
y = x + 1.0
|
||||
c = lambda i, _x, _y: math_ops.less(i, 5)
|
||||
b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
|
||||
_, _, w = control_flow_ops.while_loop(c, b,
|
||||
(constant_op.constant(0), y, x))
|
||||
u = w + y
|
||||
result = session.run(u, {x: np.float32(2)})
|
||||
self.assertAllClose(result, np.float32(63), rtol=1e-1)
|
||||
|
||||
def testGradient(self):
|
||||
"""Tests that the backprop function is properly compiled."""
|
||||
|
||||
def _Run(compiled):
|
||||
|
||||
@function.Defun(compiled=compiled)
|
||||
def Forward(x):
|
||||
return math_ops.log(x)
|
||||
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = Forward(x)
|
||||
dx, = gradients_impl.gradients(y, [x], 1.0)
|
||||
|
||||
cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
|
||||
optimizer_options=config_pb2.OptimizerOptions(
|
||||
opt_level=config_pb2.OptimizerOptions.L1,
|
||||
do_function_inlining=True)))
|
||||
with session_lib.Session(graph=g, config=cfg) as sess:
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
dx_val = sess.run(dx,
|
||||
feed_dict={x: 100.},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assertAllClose(dx_val, 0.01)
|
||||
return RunMetadataLabels(run_metadata)
|
||||
|
||||
# SymGrad[f=log(x)](x, dy) = 1/x * dy
|
||||
#
|
||||
# Note: we don't need to compute log(x) for dx due to graph pruning.
|
||||
|
||||
# Do not compile the backprop. We should see one Reciprocal and one Mul.
|
||||
labels = _Run(compiled=False)
|
||||
self.assertFalse(InLabels(labels, "Log"))
|
||||
self.assertTrue(InLabels(labels, "Reciprocal"))
|
||||
self.assertTrue(InLabels(labels, "Mul"))
|
||||
self.assertFalse(InLabels(labels, "_XlaLaunch"))
|
||||
|
||||
# Compile the backprop. One _XlaLaunch.
|
||||
labels = _Run(compiled=True)
|
||||
self.assertFalse(InLabels(labels, "Log"))
|
||||
self.assertFalse(InLabels(labels, "Reciprocal"))
|
||||
self.assertFalse(InLabels(labels, "Mul"))
|
||||
self.assertTrue(InLabels(labels, "_XlaLaunch"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
129
tensorflow/compiler/tests/lrn_ops_test.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Local Response Normalization ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
|
||||
# Local response normalization tests. The forward tests are copied from
|
||||
# tensorflow/python/kernel_tests/lrn_op_test.py
|
||||
class LRNTest(XLATestCase):
|
||||
|
||||
def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
|
||||
beta=0.5):
|
||||
"""Compute expected result."""
|
||||
output = copy.deepcopy(input_image)
|
||||
batch_size = input_image.shape[0]
|
||||
rows = input_image.shape[1]
|
||||
cols = input_image.shape[2]
|
||||
depth = input_image.shape[3]
|
||||
for b in range(batch_size):
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
for d in range(depth):
|
||||
begin = max(0, d - lrn_depth_radius)
|
||||
end = min(depth, d + lrn_depth_radius + 1)
|
||||
patch = input_image[b, r, c, begin:end]
|
||||
output[b, r, c, d] /= (
|
||||
np.power(bias + alpha * np.sum(patch * patch), beta))
|
||||
return output
|
||||
|
||||
def _RunAndVerify(self, dtype):
|
||||
with self.test_session():
|
||||
# random shape
|
||||
shape = np.random.randint(1, 16, size=4)
|
||||
# Make depth at least 2 to make it meaningful
|
||||
shape[3] += 1
|
||||
p = array_ops.placeholder(dtype, shape=shape)
|
||||
# random depth_radius, bias, alpha, beta
|
||||
lrn_depth_radius = np.random.randint(1, shape[3])
|
||||
bias = 1.0 + np.random.rand()
|
||||
alpha = 2.0 * np.random.rand()
|
||||
beta = 2.0 * np.random.rand()
|
||||
with self.test_scope():
|
||||
lrn_t = nn.local_response_normalization(
|
||||
p,
|
||||
name="lrn",
|
||||
depth_radius=lrn_depth_radius,
|
||||
bias=bias,
|
||||
alpha=alpha,
|
||||
beta=beta)
|
||||
params = {p: np.random.rand(*shape).astype("f")}
|
||||
result = lrn_t.eval(feed_dict=params)
|
||||
expected = self._LRN(
|
||||
params[p],
|
||||
lrn_depth_radius=lrn_depth_radius,
|
||||
bias=bias,
|
||||
alpha=alpha,
|
||||
beta=beta)
|
||||
err = np.amax(np.abs(result - expected))
|
||||
print("LRN error for bias ", bias, "alpha ", alpha, " beta ", beta, " is ",
|
||||
err)
|
||||
if dtype == dtypes.float32:
|
||||
self.assertTrue(err < 1e-4)
|
||||
else:
|
||||
self.assertTrue(err < 1e-2)
|
||||
self.assertShapeEqual(expected, lrn_t)
|
||||
|
||||
def testCompute(self):
|
||||
for _ in range(2):
|
||||
self._RunAndVerify(dtypes.float32)
|
||||
|
||||
def testLrnGrad(self):
|
||||
# Test for LRNGrad that compares against the CPU implementation.
|
||||
shape = [1, 2, 3, 4]
|
||||
total_size = np.prod(shape)
|
||||
in_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
|
||||
out_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
|
||||
out_grads_vals = np.arange(1, total_size + 1, dtype=np.float32)
|
||||
depth_radius = np.random.randint(1, shape[3])
|
||||
bias = 1.0 + np.random.rand()
|
||||
alpha = 1.0 * np.random.rand()
|
||||
beta = 1.0 * np.random.rand()
|
||||
|
||||
with self.test_session():
|
||||
in_image = constant_op.constant(in_image_vals, shape=shape)
|
||||
out_image = constant_op.constant(out_image_vals, shape=shape)
|
||||
out_grads = constant_op.constant(out_grads_vals, shape=shape)
|
||||
with ops.device(CPU_DEVICE):
|
||||
expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
|
||||
depth_radius, bias, alpha, beta)
|
||||
with self.test_scope():
|
||||
actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
|
||||
depth_radius, bias, alpha, beta)
|
||||
expected_val = expected.eval()
|
||||
actual_val = actual.eval()
|
||||
self.assertAllClose(actual_val, expected_val, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
158
tensorflow/compiler/tests/lstm.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A simple LSTM layer with benchmarks.
|
||||
|
||||
This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed
|
||||
length sequence. The only deviation from standard LSTM cells is that
|
||||
activations are clipped, inspired by the GNMT machine translation model.
|
||||
The GNMT paper has more details: https://arxiv.org/abs/1609.08144
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
def Clip(x):
|
||||
"""Clips x to the range [-1., 1.]."""
|
||||
return math_ops.maximum(math_ops.minimum(x, 1.), -1.)
|
||||
|
||||
|
||||
def LSTMCellWeightsShape(num_inputs, num_nodes):
|
||||
"""Returns the shape of the weights for a single LSTM cell."""
|
||||
# Dimension 0 accounts for combining x with the previous m state.
|
||||
# Dimension 1 accounts for the in value and the (in, forget, out) gates.
|
||||
return [num_inputs + num_nodes, 4 * num_nodes]
|
||||
|
||||
|
||||
def LSTMCell(weights, m_prev, c_prev, x, pad):
|
||||
"""Unrolls a single LSTM cell with clipped activations forward by one step.
|
||||
|
||||
Args:
|
||||
weights: Weight matrix with shape LSTMCellWeightsShape.
|
||||
m_prev: Previous m states with shape [batch_size, num_nodes].
|
||||
c_prev: Previous c states with shape [batch_size, num_nodes].
|
||||
x: Input with shape [batch_size, num_inputs].
|
||||
pad: Padding with shape [batch_size, 1]. Each padding value is either
|
||||
0 or 1, where 1 indicates padding; i.e. the input is shorter than the
|
||||
sequence length, and the (m, c) states should simply be passed through
|
||||
from the previous states.
|
||||
Returns:
|
||||
The next (m, c) states, each with shape [batch_size, num_nodes].
|
||||
"""
|
||||
# Apply weights to the input and previous hidden state.
|
||||
# The matmul here is the "big" operation.
|
||||
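# Concatenating x ([batch_size, num_inputs]) with m_prev ([batch_size,
# num_nodes]) matches dimension 0 of the weights from LSTMCellWeightsShape.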
xm = array_ops.concat_v2([x, m_prev], 1)
|
||||
xmw = math_ops.matmul(xm, weights)
|
||||
|
||||
# Element-wise ops for the standard LSTM cell, with clipped activations.
|
||||
# XLA can fuse these operations into a single loop.
|
||||
in_value, in_gate, forget_gate, out_gate = array_ops.split(
|
||||
value=xmw, num_or_size_splits=4, axis=1)
|
||||
in_value = math_ops.tanh(in_value)
|
||||
in_gate = math_ops.sigmoid(in_gate)
|
||||
forget_gate = math_ops.sigmoid(forget_gate)
|
||||
out_gate = math_ops.sigmoid(out_gate)
|
||||
c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value))
|
||||
m_next = Clip(out_gate * c_next)
|
||||
|
||||
# Account for padding.
|
||||
c_next = c_prev * pad + c_next * (1.0 - pad)
|
||||
m_next = m_prev * pad + m_next * (1.0 - pad)
|
||||
|
||||
return m_next, c_next
|
||||
|
||||
|
||||
def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
|
||||
"""Unrolls a layer of LSTM cells forward by the sequence length.
|
||||
|
||||
The sequence length is determined by the length of x_seq and pad_seq, which
|
||||
must be the same.
|
||||
|
||||
Args:
|
||||
cell_name: Base name of each cell.
|
||||
weights: Weight matrix with shape LSTMCellWeightsShape.
|
||||
m: Initial m states with shape [batch_size, num_nodes].
|
||||
c: Initial c states with shape [batch_size, num_nodes].
|
||||
x_seq: List of inputs, each with shape [batch_size, num_inputs].
|
||||
The length of the list is the sequence length.
|
||||
pad_seq: List of paddings, each with shape [batch_size, 1].
|
||||
The length of the list is the sequence length.
|
||||
Each padding value is either 0 or 1, where 1 indicates padding;
|
||||
i.e. the input is shorter than the sequence length.
|
||||
Returns:
|
||||
List of per-sequence-step outputs, each with shape [batch_size, num_nodes].
|
||||
Raises:
|
||||
ValueError: If len(x_seq) != len(pad_seq).
|
||||
"""
|
||||
if len(x_seq) != len(pad_seq):
|
||||
raise ValueError('length of x_seq(%d) != pad_seq(%d)' %
|
||||
(len(x_seq), len(pad_seq)))
|
||||
out_seq = []
|
||||
for seq in range(len(x_seq)):
|
||||
with ops.name_scope('%s_%d' % (cell_name, seq)):
|
||||
m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq])
|
||||
out_seq.append(array_ops.identity(m, name='out'))
|
||||
return out_seq
|
||||
|
||||
|
||||
def RandomVar(shape, name=None):
|
||||
"""Returns a variable of the given shape initialized to random values."""
|
||||
return variables.Variable(
|
||||
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
|
||||
|
||||
|
||||
def RandomInputs(batch_size, seq_length, num_inputs):
|
||||
"""Returns randomly initialized (x_seq, pad_seq) sequences."""
|
||||
x_seq = []
|
||||
pad_seq = []
|
||||
with ops.name_scope('inputs'):
|
||||
for seq in range(seq_length):
|
||||
x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq))
|
||||
# Real padding values are always a sequence of 0 followed by a
|
||||
# sequence of 1, but random values are fine for benchmarking.
|
||||
pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq))
|
||||
return x_seq, pad_seq
|
||||
|
||||
|
||||
def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes):
|
||||
"""Builds a single LSTM layer with random weights and inputs.
|
||||
|
||||
Args:
|
||||
batch_size: Inputs are fed in batches of this size.
|
||||
seq_length: The sequence length to unroll the LSTM layer.
|
||||
num_inputs: Dimension of inputs that are fed into each LSTM cell.
|
||||
num_nodes: The number of nodes in each LSTM cell.
|
||||
|
||||
Returns:
|
||||
(out_seq, weights) pair. The out_seq is a list of per-sequence-step
|
||||
outputs, each with shape [batch_size, num_nodes]. The weights are a list of
|
||||
weight variables that may be trained.
|
||||
"""
|
||||
weights = RandomVar(
|
||||
LSTMCellWeightsShape(num_inputs, num_nodes), name='weights')
|
||||
m = array_ops.zeros([batch_size, num_nodes], name='init_m')
|
||||
c = array_ops.zeros([batch_size, num_nodes], name='init_c')
|
||||
x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs)
|
||||
|
||||
out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq)
|
||||
return out_seq, [weights]
|
20
tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
Normal file
@ -0,0 +1,20 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
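# Each feed pins a graph input node to a fixed shape; each fetch names an
# output node that the compiled computation must produce.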
|
||||
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"inputs/x_seq_3/read"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"inputs/x_seq_4/read"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"inputs/pad_seq_0/read"} shape{dim{size:128}dim{size:1}} }
|
||||
feed{ id{node_name:"inputs/pad_seq_1/read"} shape{dim{size:128}dim{size:1}} }
|
||||
feed{ id{node_name:"inputs/pad_seq_2/read"} shape{dim{size:128}dim{size:1}} }
|
||||
feed{ id{node_name:"inputs/pad_seq_3/read"} shape{dim{size:128}dim{size:1}} }
|
||||
feed{ id{node_name:"inputs/pad_seq_4/read"} shape{dim{size:128}dim{size:1}} }
|
||||
feed{ id{node_name:"weights/read"} shape{dim{size:2048}dim{size:4096}} }
|
||||
feed{ id{node_name:"init_c"} shape{dim{size:128}dim{size:1024}} }
|
||||
feed{ id{node_name:"init_m"} shape{dim{size:128}dim{size:1024}} }
|
||||
|
||||
fetch{ id{node_name:"lstm_0/out"} }
|
||||
fetch{ id{node_name:"lstm_1/out"} }
|
||||
fetch{ id{node_name:"lstm_2/out"} }
|
||||
fetch{ id{node_name:"lstm_3/out"} }
|
||||
fetch{ id{node_name:"lstm_4/out"} }
|
5828
tensorflow/compiler/tests/lstm_layer_inference.pbtxt
Normal file
File diff suppressed because it is too large
293
tensorflow/compiler/tests/lstm_test.py
Normal file
@ -0,0 +1,293 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the LSTM cell and layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import lstm
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import flags as flags_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
flags = flags_lib
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_integer('batch_size', 128,
|
||||
'Inputs are fed in batches of this size, for both '
|
||||
'inference and training. Larger values cause the matmul '
|
||||
'in each LSTM cell to have higher dimensionality.')
|
||||
flags.DEFINE_integer('seq_length', 60,
|
||||
'Length of the unrolled sequence of LSTM cells in a layer. '
|
||||
'Larger values cause more LSTM matmuls to be run.')
|
||||
flags.DEFINE_integer('num_inputs', 1024,
|
||||
'Dimension of inputs that are fed into each LSTM cell.')
|
||||
flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
|
||||
flags.DEFINE_string('device', 'gpu',
|
||||
'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
|
||||
'For details see documentation for tf.Graph.device.')
|
||||
|
||||
flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
|
||||
'*.pbtxt format to this directory.')
|
||||
|
||||
|
||||
def _DumpGraph(graph, basename):
|
||||
if FLAGS.dump_graph_dir:
|
||||
name = os.path.join(FLAGS.dump_graph_dir, basename + '.pbtxt')
|
||||
with open(name, 'w') as f:
|
||||
f.write(str(graph.as_graph_def()))
|
||||
|
||||
|
||||
def _Sigmoid(x):
|
||||
return 1. / (1. + np.exp(-x))
|
||||
|
||||
|
||||
def _Clip(x):
|
||||
return np.maximum(np.minimum(x, 1.), -1.)
|
||||
|
||||
|
||||
class LSTMTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# The tests for a single LSTM cell and LSTM layer use these values as
|
||||
# inputs. We always set the dimensionality of num_inputs=1; thus batch_size
|
||||
# actually represents the different input cases.
|
||||
self._inputs = np.array([[-1.], [-.5], [0.], [.5], [1.]], np.float32)
|
||||
self._batch_size = len(self._inputs)
|
||||
|
||||
def _NextC(self, inputs, weight, m_prev, c_prev):
|
||||
"""Returns the next c states of an LSTM cell."""
|
||||
x = (inputs + m_prev) * weight
|
||||
return _Clip(_Clip(_Sigmoid(x) * c_prev) + _Clip(_Sigmoid(x) * np.tanh(x)))
|
||||
|
||||
def _NextM(self, inputs, weight, m_prev, c_prev):
|
||||
"""Returns the next m states of an LSTM cell."""
|
||||
x = (inputs + m_prev) * weight
|
||||
return _Clip(_Sigmoid(x) * self._NextC(inputs, weight, m_prev, c_prev))
|
||||
|
||||
def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
|
||||
pad_scalar):
|
||||
with self.test_session() as sess:
|
||||
num_inputs = 1
|
||||
num_nodes = 1
|
||||
|
||||
weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
|
||||
m_prev = constant_op.constant([[m_prev_scalar]] * self._batch_size)
|
||||
c_prev = constant_op.constant([[c_prev_scalar]] * self._batch_size)
|
||||
x = constant_op.constant(self._inputs)
|
||||
pad = constant_op.constant([[pad_scalar]] * self._batch_size)
|
||||
|
||||
m, c = lstm.LSTMCell(weights, m_prev, c_prev, x, pad)
|
||||
_DumpGraph(sess.graph, 'lstm_cell_%s_%d_%d_%d' %
|
||||
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM step.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run([m, c])
|
||||
|
||||
def testLSTMCell(self):
|
||||
# Run with all-0 weights, no padding.
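# With zero weights every gate is sigmoid(0) = 0.5 and the candidate value is
# tanh(0) = 0, so c_next = 0.5 * c_prev and m_next = 0.25 * c_prev.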
|
||||
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
|
||||
self.assertAllClose(m, [[0.]] * self._batch_size)
|
||||
self.assertAllClose(c, [[0.]] * self._batch_size)
|
||||
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
|
||||
self.assertAllClose(m, [[.25]] * self._batch_size)
|
||||
self.assertAllClose(c, [[.5]] * self._batch_size)
|
||||
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
|
||||
self.assertAllClose(m, [[.0]] * self._batch_size)
|
||||
self.assertAllClose(c, [[.0]] * self._batch_size)
|
||||
m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
|
||||
self.assertAllClose(m, [[.25]] * self._batch_size)
|
||||
self.assertAllClose(c, [[.5]] * self._batch_size)
|
||||
|
||||
# Run with all-1 weights, no padding.
|
||||
for m_prev in [0., 1.]:
|
||||
for c_prev in [0., 1.]:
|
||||
m, c = self._RunLSTMCell('ones',
|
||||
init_ops.ones_initializer(), m_prev, c_prev,
|
||||
0.)
|
||||
self.assertAllClose(m, self._NextM(self._inputs, 1., m_prev, c_prev))
|
||||
self.assertAllClose(c, self._NextC(self._inputs, 1., m_prev, c_prev))
|
||||
|
||||
# Run with random weights.
|
||||
for weight in np.random.rand(3):
|
||||
weight_tf = constant_op.constant(weight, dtypes.float32)
|
||||
random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
|
||||
|
||||
# No padding.
|
||||
for m_prev in [0., 1.]:
|
||||
for c_prev in [0., 1.]:
|
||||
m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 0.)
|
||||
self.assertAllClose(m,
|
||||
self._NextM(self._inputs, weight, m_prev, c_prev))
|
||||
self.assertAllClose(c,
|
||||
self._NextC(self._inputs, weight, m_prev, c_prev))
|
||||
|
||||
# Set padding.
|
||||
for m_prev in [0., 1.]:
|
||||
for c_prev in [0., 1.]:
|
||||
m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 1.)
|
||||
self.assertAllClose(m, [[m_prev]] * self._batch_size)
|
||||
self.assertAllClose(c, [[c_prev]] * self._batch_size)
|
||||
|
||||
def testLSTMLayerErrors(self):
|
||||
num_inputs = 1
|
||||
num_nodes = 1
|
||||
seq_length = 3
|
||||
|
||||
weights = array_ops.zeros(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
|
||||
m = constant_op.constant([[0.]] * self._batch_size)
|
||||
c = constant_op.constant([[0.]] * self._batch_size)
|
||||
x_seq = [constant_op.constant(self._inputs)] * seq_length
|
||||
pad = constant_op.constant([[0.]] * self._batch_size)
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
|
||||
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad])
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
|
||||
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 2)
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
|
||||
lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 4)
|
||||
|
||||
def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
|
||||
pad_scalar):
|
||||
with self.test_session() as sess:
|
||||
num_inputs = 1
|
||||
num_nodes = 1
|
||||
seq_length = 3
|
||||
|
||||
weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
|
||||
m_init = constant_op.constant([[m_init_scalar]] * self._batch_size)
|
||||
c_init = constant_op.constant([[c_init_scalar]] * self._batch_size)
|
||||
x_seq = [constant_op.constant(self._inputs)] * seq_length
|
||||
pad_seq = [constant_op.constant([[pad_scalar]] * self._batch_size)
|
||||
] * seq_length
|
||||
|
||||
out_seq = lstm.LSTMLayer('lstm', weights, m_init, c_init, x_seq, pad_seq)
|
||||
_DumpGraph(sess.graph, 'lstm_layer_%s_%d_%d_%d' %
|
||||
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM layer.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run(out_seq)
|
||||
|
||||
def testLSTMLayer(self):
|
||||
# Run with all-0 weights, no padding.
|
||||
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
|
||||
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
|
||||
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
|
||||
self.assertAllClose(o, [[[.25]] * self._batch_size,
|
||||
[[.125]] * self._batch_size,
|
||||
[[.0625]] * self._batch_size])
|
||||
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
|
||||
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
|
||||
o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
|
||||
self.assertAllClose(o, [[[.25]] * self._batch_size,
|
||||
[[.125]] * self._batch_size,
|
||||
[[.0625]] * self._batch_size])
|
||||
|
||||
# Run with all-1 weights, no padding.
|
||||
weight1 = 1.
|
||||
for m_init in [0., 1.]:
|
||||
for c_init in [0., 1.]:
|
||||
o = self._RunLSTMLayer('ones',
|
||||
init_ops.ones_initializer(), m_init, c_init, 0.)
|
||||
m0 = self._NextM(self._inputs, weight1, m_init, c_init)
|
||||
c0 = self._NextC(self._inputs, weight1, m_init, c_init)
|
||||
self.assertAllClose(o[0], m0)
|
||||
m1 = self._NextM(self._inputs, weight1, m0, c0)
|
||||
c1 = self._NextC(self._inputs, weight1, m0, c0)
|
||||
self.assertAllClose(o[1], m1)
|
||||
m2 = self._NextM(self._inputs, weight1, m1, c1)
|
||||
self.assertAllClose(o[2], m2)
|
||||
|
||||
# Run with random weights.
|
||||
for weight in np.random.rand(3):
|
||||
weight_tf = constant_op.constant(weight, dtypes.float32)
|
||||
random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
|
||||
|
||||
# No padding.
|
||||
for m_init in [0., 1.]:
|
||||
for c_init in [0., 1.]:
|
||||
o = self._RunLSTMLayer('random', random_weight, m_init, c_init, 0.)
|
||||
m0 = self._NextM(self._inputs, weight, m_init, c_init)
|
||||
c0 = self._NextC(self._inputs, weight, m_init, c_init)
|
||||
self.assertAllClose(o[0], m0)
|
||||
m1 = self._NextM(self._inputs, weight, m0, c0)
|
||||
c1 = self._NextC(self._inputs, weight, m0, c0)
|
||||
self.assertAllClose(o[1], m1)
|
||||
m2 = self._NextM(self._inputs, weight, m1, c1)
|
||||
self.assertAllClose(o[2], m2)
|
||||
|
||||
# Set padding.
|
||||
o = self._RunLSTMLayer('random', random_weight, 0., 0., 1.)
|
||||
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
|
||||
o = self._RunLSTMLayer('random', random_weight, 0., 1., 1.)
|
||||
self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
|
||||
o = self._RunLSTMLayer('random', random_weight, 1., 0., 1.)
|
||||
self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
|
||||
o = self._RunLSTMLayer('random', random_weight, 1., 1., 1.)
|
||||
self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
|
||||
|
||||
|
||||
class LSTMBenchmark(test.Benchmark):
|
||||
"""Mcro-benchmarks for a single layer of LSTM cells."""
|
||||
|
||||
def _LayerBuilder(self, do_training):
|
||||
out_seq, weights = lstm.BuildLSTMLayer(FLAGS.batch_size, FLAGS.seq_length,
|
||||
FLAGS.num_inputs, FLAGS.num_nodes)
|
||||
name, fetches = ('lstm_layer_inference', out_seq)
|
||||
if do_training:
|
||||
# Not a real loss function, but good enough for benchmarking backprop.
|
||||
loss = math_ops.reduce_sum(math_ops.add_n(out_seq))
|
||||
dw = gradients_impl.gradients(loss, weights)
|
||||
name, fetches = ('lstm_layer_training', dw)
|
||||
|
||||
_DumpGraph(ops.get_default_graph(),
|
||||
'%s_%d_%d_%d_%d' % (name, FLAGS.batch_size, FLAGS.seq_length,
|
||||
FLAGS.num_inputs, FLAGS.num_nodes))
|
||||
return name, fetches
|
||||
|
||||
def benchmarkLayerInference(self):
|
||||
xla_test.Benchmark(self, lambda: self._LayerBuilder(False), False,
|
||||
FLAGS.device)
|
||||
|
||||
def benchmarkLayerInferenceXLA(self):
|
||||
xla_test.Benchmark(self, lambda: self._LayerBuilder(False), True,
|
||||
FLAGS.device)
|
||||
|
||||
def benchmarkLayerTraining(self):
|
||||
xla_test.Benchmark(self, lambda: self._LayerBuilder(True), False,
|
||||
FLAGS.device)
|
||||
|
||||
def benchmarkLayerTrainingXLA(self):
|
||||
xla_test.Benchmark(self, lambda: self._LayerBuilder(True), True,
|
||||
FLAGS.device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
209
tensorflow/compiler/tests/nary_ops_test.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for operators with > 3 or arbitrary numbers of arguments."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class NAryOpsTest(XLATestCase):
|
||||
|
||||
def _testNAry(self, op, args, expected):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
placeholders = [
|
||||
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
|
||||
for arg in args
|
||||
]
|
||||
feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
|
||||
output = op(placeholders)
|
||||
result = session.run(output, feeds)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFloat(self):
|
||||
self._testNAry(math_ops.add_n,
|
||||
[np.array([[1, 2, 3]], dtype=np.float32)],
|
||||
expected=np.array([[1, 2, 3]], dtype=np.float32))
|
||||
|
||||
self._testNAry(math_ops.add_n,
|
||||
[np.array([1, 2], dtype=np.float32),
|
||||
np.array([10, 20], dtype=np.float32)],
|
||||
expected=np.array([11, 22], dtype=np.float32))
|
||||
self._testNAry(math_ops.add_n,
|
||||
[np.array([-4], dtype=np.float32),
|
||||
np.array([10], dtype=np.float32),
|
||||
np.array([42], dtype=np.float32)],
|
||||
expected=np.array([48], dtype=np.float32))
|
||||
|
||||
def testConcat(self):
|
||||
self._testNAry(
|
||||
lambda x: array_ops.concat_v2(x, 0), [
|
||||
np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
|
||||
[[7, 8, 9], [10, 11, 12]], dtype=np.float32)
|
||||
],
|
||||
expected=np.array(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))
|
||||
|
||||
self._testNAry(
|
||||
lambda x: array_ops.concat_v2(x, 1), [
|
||||
np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
|
||||
[[7, 8, 9], [10, 11, 12]], dtype=np.float32)
|
||||
],
|
||||
expected=np.array(
|
||||
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
|
||||
|
||||
def testSplitV(self):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
output = session.run(
|
||||
array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
|
||||
dtype=np.float32),
|
||||
[2, 2], 1))
|
||||
expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
|
||||
np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
|
||||
self.assertAllEqual(output, expected)
|
||||
|
||||
def testStridedSlice(self):
|
||||
self._testNAry(lambda x: array_ops.strided_slice(*x),
|
||||
[np.array([[], [], []], dtype=np.float32),
|
||||
np.array([1, 0], dtype=np.int32),
|
||||
np.array([3, 0], dtype=np.int32),
|
||||
np.array([1, 1], dtype=np.int32)],
|
||||
expected=np.array([[], []], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice(*x),
|
||||
[np.array([[], [], []], dtype=np.float32),
|
||||
np.array([1, 0], dtype=np.int64),
|
||||
np.array([3, 0], dtype=np.int64),
|
||||
np.array([1, 1], dtype=np.int64)],
|
||||
expected=np.array([[], []], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice(*x),
|
||||
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
dtype=np.float32),
|
||||
np.array([1, 1], dtype=np.int32),
|
||||
np.array([3, 3], dtype=np.int32),
|
||||
np.array([1, 1], dtype=np.int32)],
|
||||
expected=np.array([[5, 6], [8, 9]], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice(*x),
|
||||
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
dtype=np.float32),
|
||||
np.array([0, 2], dtype=np.int32),
|
||||
np.array([2, 0], dtype=np.int32),
|
||||
np.array([1, -1], dtype=np.int32)],
|
||||
expected=np.array([[3, 2], [6, 5]], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
|
||||
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
dtype=np.float32)],
|
||||
expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
|
||||
dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
|
||||
[np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
dtype=np.float32)],
|
||||
expected=np.array([[4], [5], [6]], dtype=np.float32))
|
||||
|
||||
def testStridedSliceGrad(self):
|
||||
# Tests cases where the input shape is empty.
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([], dtype=np.int32),
|
||||
np.array([], dtype=np.int32),
|
||||
np.array([], dtype=np.int32),
|
||||
np.array([], dtype=np.int32),
|
||||
np.float32(0.5)],
|
||||
expected=np.array(np.float32(0.5), dtype=np.float32))
|
||||
|
||||
# Tests the case where the input shape is non-empty but the gradients are empty.
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([3], dtype=np.int32),
|
||||
np.array([0], dtype=np.int32),
|
||||
np.array([0], dtype=np.int32),
|
||||
np.array([1], dtype=np.int32),
|
||||
np.array([], dtype=np.float32)],
|
||||
expected=np.array([0, 0, 0], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([3, 0], dtype=np.int32),
|
||||
np.array([1, 0], dtype=np.int32),
|
||||
np.array([3, 0], dtype=np.int32),
|
||||
np.array([1, 1], dtype=np.int32),
|
||||
np.array([[], []], dtype=np.float32)],
|
||||
expected=np.array([[], [], []], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([3, 3], dtype=np.int32),
|
||||
np.array([1, 1], dtype=np.int32),
|
||||
np.array([3, 3], dtype=np.int32),
|
||||
np.array([1, 1], dtype=np.int32),
|
||||
np.array([[5, 6], [8, 9]], dtype=np.float32)],
|
||||
expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
|
||||
dtype=np.float32))
|
||||
|
||||
def ssg_test(x):
|
||||
return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
|
||||
new_axis_mask=0x1)
|
||||
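# Note (illustrative, not part of this change): the masks are bitmasks over
# positions in the slice specification, with the same meaning as in the
# forward strided_slice: in ssg_test, new_axis_mask=0x1 (bit 0) marks a
# new size-1 axis at position 0 and shrink_axis_mask=0x4 (bit 2) marks an
# axis that was indexed away at position 2; ssg_test2 below uses
# new_axis_mask=0x15 (bits 0, 2 and 4).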
|
||||
self._testNAry(ssg_test,
|
||||
[np.array([3, 1, 3], dtype=np.int32),
|
||||
np.array([0, 0, 0, 2], dtype=np.int32),
|
||||
np.array([0, 3, 1, -4], dtype=np.int32),
|
||||
np.array([1, 2, 1, -3], dtype=np.int32),
|
||||
np.array([[[1], [2]]], dtype=np.float32)],
|
||||
expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
|
||||
dtype=np.float32))
|
||||
|
||||
ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15)
|
||||
self._testNAry(ssg_test2,
|
||||
[np.array([4, 4], dtype=np.int32),
|
||||
np.array([0, 0, 0, 1, 0], dtype=np.int32),
|
||||
np.array([0, 3, 0, 4, 0], dtype=np.int32),
|
||||
np.array([1, 2, 1, 2, 1], dtype=np.int32),
|
||||
np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
|
||||
expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
|
||||
[0, 0, 0, 0]], dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([3, 3], dtype=np.int32),
|
||||
np.array([0, 2], dtype=np.int32),
|
||||
np.array([2, 0], dtype=np.int32),
|
||||
np.array([1, -1], dtype=np.int32),
|
||||
np.array([[1, 2], [3, 4]], dtype=np.float32)],
|
||||
expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
|
||||
dtype=np.float32))
|
||||
|
||||
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
|
||||
[np.array([3, 3], dtype=np.int32),
|
||||
np.array([2, 2], dtype=np.int32),
|
||||
np.array([0, 1], dtype=np.int32),
|
||||
np.array([-1, -2], dtype=np.int32),
|
||||
np.array([[1], [2]], dtype=np.float32)],
|
||||
expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
|
||||
dtype=np.float32))
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
61
tensorflow/compiler/tests/nullary_ops_test.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for operators with no arguments."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class NullaryOpsTest(XLATestCase):
|
||||
|
||||
def _testNullary(self, op, expected):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
output = op()
|
||||
result = session.run(output)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testNoOp(self):
|
||||
with self.test_session():
|
||||
with self.test_scope():
|
||||
output = control_flow_ops.no_op()
|
||||
# This should not crash.
|
||||
output.run()
|
||||
|
||||
def testConstants(self):
|
||||
constants = [
|
||||
np.float32(42),
|
||||
np.array([], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32),
|
||||
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
|
||||
np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
|
||||
dtype=np.float32),
|
||||
np.array([[[]], [[]]], dtype=np.float32),
|
||||
np.array([[[[1]]]], dtype=np.float32),
|
||||
]
|
||||
for c in constants:
|
||||
self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
511
tensorflow/compiler/tests/pooling_ops_test.py
Normal file
@ -0,0 +1,511 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functional tests for pooling operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
def NHWCToNCHW(input_tensor):
|
||||
"""Convert the input from NHWC format to NCHW.
|
||||
|
||||
Args:
|
||||
input_tensor: a 4-D tensor, or a 4-element array representing the same.
|
||||
|
||||
Returns:
|
||||
the converted tensor or a shape array
|
||||
"""
|
||||
if isinstance(input_tensor, ops.Tensor):
|
||||
return array_ops.transpose(input_tensor, [0, 3, 1, 2])
|
||||
else:
|
||||
return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]]
|
||||
|
||||
|
||||
def NCHWToNHWC(input_tensor):
|
||||
"""Convert the input from NCHW format to NHWC.
|
||||
|
||||
Args:
|
||||
input_tensor: a 4-D tensor, or a 4-element array representing the same.
|
||||
|
||||
Returns:
|
||||
the converted tensor or a shape array
|
||||
"""
|
||||
if isinstance(input_tensor, ops.Tensor):
|
||||
return array_ops.transpose(input_tensor, [0, 2, 3, 1])
|
||||
else:
|
||||
return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]]
|
||||
|
||||
|
||||
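# Illustrative usage (not part of this change): the two helpers above are
# inverse permutations and accept either a tf.Tensor or a plain 4-element
# shape list, e.g. NHWCToNCHW([1, 7, 7, 3]) -> [1, 3, 7, 7] and
# NCHWToNHWC([1, 3, 7, 7]) -> [1, 7, 7, 3].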
def GetTestConfigs():
|
||||
"""Get all the valid tests configs to run.
|
||||
|
||||
Returns:
|
||||
all the valid test configs
|
||||
"""
|
||||
test_configs = ["NHWC", "NCHW"]
|
||||
return test_configs
|
||||
|
||||
|
||||
class PoolingTest(XLATestCase):
|
||||
|
||||
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
pool_func: Function to be called, e.g. nn_ops.max_pool or nn_ops.avg_pool.
|
||||
input_sizes: Input tensor dimensions.
|
||||
ksize: The kernel size dimensions
|
||||
strides: The stride dimensions
|
||||
padding: Padding type.
|
||||
data_format: The data format we use to run the pooling operation.
|
||||
expected: An array containing the expected operation outputs.
|
||||
"""
|
||||
total_size = np.prod(input_sizes)
|
||||
# Initializes the input tensor with an array containing incrementing
|
||||
# numbers from 1.
|
||||
x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
|
||||
x = x.reshape(input_sizes)
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
inputs = array_ops.placeholder(dtypes.float32)
|
||||
t = inputs
|
||||
if data_format == "NCHW":
|
||||
t = NHWCToNCHW(t)
|
||||
ksize = NHWCToNCHW(ksize)
|
||||
strides = NHWCToNCHW(strides)
|
||||
t = pool_func(t,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format)
|
||||
if data_format == "NCHW":
|
||||
t = NCHWToNHWC(t)
|
||||
actual = sess.run(t, {inputs: x})
|
||||
self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
expected):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
pool_func: Function to be called, e.g. nn_ops.max_pool or nn_ops.avg_pool.
|
||||
input_sizes: Input tensor dimensions.
|
||||
ksize: The kernel size dimensions
|
||||
strides: The stride dimensions
|
||||
padding: Padding type.
|
||||
expected: An array containing the expected operation outputs.
|
||||
"""
|
||||
for data_format in GetTestConfigs():
|
||||
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected)
|
||||
|
||||
def testMaxPoolValidPadding(self):
|
||||
expected_output = [13.0, 14.0, 15.0]
|
||||
self._VerifyValues(nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
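# Illustrative (not part of this change): the input is filled with 1..27 in
# NHWC order, so with a 2x2 window, stride 2 and VALID padding there is a
# single output position whose per-channel maxima come from spatial location
# (1, 1), i.e. the values 13, 14 and 15.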
|
||||
def testMaxPoolSamePadding(self):
|
||||
expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
|
||||
self._VerifyValues(nn_ops.max_pool,
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testMaxPoolSamePaddingNonSquareWindow(self):
|
||||
# input is:
|
||||
# [1.0, 2.0
|
||||
# 3.0 4.0]
|
||||
#
|
||||
# Window of [x, x] should do:
|
||||
#
|
||||
# [max(1.0, 2.0), max(2.0, padded0),
|
||||
# max(3.0, 4.0), max(4.0, padded0)]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 2, 2, 1],
|
||||
ksize=[1, 1, 2, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
padding="SAME",
|
||||
expected=[2.0, 2.0, 4.0, 4.0])
|
||||
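# Illustrative NumPy check of the expected values above (standalone sketch,
# not part of this change):
#   x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
#   padded = np.pad(x, ((0, 0), (0, 1)), mode='constant',
#                   constant_values=-np.inf)
#   np.maximum(padded[:, :-1], padded[:, 1:]).flatten()  # -> [2., 2., 4., 4.]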
|
||||
def testMaxPoolValidPaddingUnevenStride(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 1, 2, 1],
|
||||
padding="VALID",
|
||||
expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0])
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 1, 1],
|
||||
padding="VALID",
|
||||
expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0])
|
||||
|
||||
def testMaxPoolSamePaddingFilter4(self):
|
||||
expected_output = [
|
||||
21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
|
||||
61.0, 62.0, 63.0, 64.0
|
||||
]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 4, 4, 4],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
def testMaxPoolSamePaddingFilter8(self):
|
||||
expected_output = [
|
||||
145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
|
||||
163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
|
||||
181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
|
||||
191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
|
||||
289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
|
||||
307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
|
||||
317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
|
||||
407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
|
||||
433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
|
||||
443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
|
||||
469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
|
||||
487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
|
||||
505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
|
||||
]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 8, 8, 8],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
# Tests for DepthwiseMaxPooling on CPU only.
|
||||
def testDepthwiseMaxPool1x1DepthWindow1(self):
|
||||
# input is:
|
||||
# [1.0, ..., 10.0] along depth,
|
||||
#
|
||||
# We maxpool by depth in patches of 2.
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 1, 1, 10],
|
||||
ksize=[1, 1, 1, 2],
|
||||
strides=[1, 1, 1, 2],
|
||||
padding="SAME",
|
||||
expected=[2.0, 4.0, 6.0, 8.0, 10.0])
|
||||
|
||||
def testDepthwiseMaxPool2x2DepthWindow3(self):
|
||||
# input is:
|
||||
#
|
||||
# a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
|
||||
# output. Each node has contiguous values, so the depthwise max
|
||||
# should be multiples of 3.0.
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 2, 2, 6],
|
||||
ksize=[1, 1, 1, 3],
|
||||
strides=[1, 1, 1, 3],
|
||||
padding="SAME",
|
||||
expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0])
|
||||
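# Illustrative NumPy check (standalone sketch, not part of this change):
#   vals = np.arange(1, 25, dtype=np.float32).reshape(1, 2, 2, 6)
#   vals.reshape(-1, 3).max(axis=1)  # -> [3, 6, 9, ..., 24], the expected output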
|
||||
def testKernelSmallerThanStrideValid(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 3, 3, 1],
|
||||
padding="VALID",
|
||||
expected=[9, 12, 30, 33])
|
||||
|
||||
def testKernelSmallerThanStrideSame(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[1, 3, 7, 9])
|
||||
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[1, 3, 9, 11])
|
||||
|
||||
# Average pooling
|
||||
def testAvgPoolValidPadding(self):
|
||||
expected_output = [7, 8, 9]
|
||||
self._VerifyValues(
|
||||
nn_ops.avg_pool,
|
||||
input_sizes=[1, 3, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="VALID",
|
||||
expected=expected_output)
|
||||
|
||||
def testAvgPoolSamePadding(self):
|
||||
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
|
||||
self._VerifyValues(
|
||||
nn_ops.avg_pool,
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output)
|
||||
|
||||
|
||||
class PoolGradTest(XLATestCase):
|
||||
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize,
|
||||
strides, padding, data_format):
|
||||
"""Verifies the output values of the pooling gradient function.
|
||||
|
||||
Args:
|
||||
pool_func: Forward pooling function
|
||||
pool_grad_func: Pooling gradient function corresponding to pool_func.
|
||||
input_sizes: Input tensor dimensions.
|
||||
ksize: The kernel size dimensions
|
||||
strides: The stride dimensions
|
||||
padding: Padding type.
|
||||
data_format: The data format we use to run the pooling operation.
|
||||
"""
|
||||
total_size = np.prod(input_sizes)
|
||||
x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
|
||||
with self.test_session() as sess:
|
||||
# Use the forward pool function to compute some corresponding outputs
|
||||
# (needed for the CPU device, and we need the shape in both cases).
|
||||
with ops.device(self.CPU_DEVICE):
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
|
||||
outputs = pool_func(
|
||||
inputs,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format="NHWC")
|
||||
|
||||
output_vals = np.array(sess.run(outputs, {inputs: x}))
|
||||
output_gradient_vals = np.arange(
|
||||
1, output_vals.size + 1, dtype=np.float32)
|
||||
output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
|
||||
|
||||
# Use the TensorFlow CPU pooling gradient to compute the expected input
|
||||
# gradients.
|
||||
with ops.device(self.CPU_DEVICE):
|
||||
output_gradients = array_ops.placeholder(
|
||||
dtypes.float32, shape=output_vals.shape)
|
||||
expected_input_gradients = pool_grad_func(
|
||||
inputs,
|
||||
outputs,
|
||||
output_gradients,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format="NHWC")
|
||||
expected_input_gradient_vals = sess.run(
|
||||
expected_input_gradients,
|
||||
{inputs: x,
|
||||
output_gradients: output_gradient_vals})
|
||||
|
||||
# Run the gradient op on the XLA device
|
||||
with self.test_scope():
|
||||
outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
|
||||
xla_inputs = inputs
|
||||
xla_outputs = outputs
|
||||
xla_output_gradients = output_gradients
|
||||
xla_ksize = ksize
|
||||
xla_strides = strides
|
||||
if data_format == "NCHW":
|
||||
xla_inputs = NHWCToNCHW(inputs)
|
||||
xla_outputs = NHWCToNCHW(outputs)
|
||||
xla_output_gradients = NHWCToNCHW(output_gradients)
|
||||
xla_ksize = NHWCToNCHW(ksize)
|
||||
xla_strides = NHWCToNCHW(strides)
|
||||
actual_input_gradients = pool_grad_func(
|
||||
xla_inputs,
|
||||
xla_outputs,
|
||||
xla_output_gradients,
|
||||
ksize=xla_ksize,
|
||||
strides=xla_strides,
|
||||
padding=padding,
|
||||
data_format=data_format)
|
||||
if data_format == "NCHW":
|
||||
actual_input_gradients = NCHWToNHWC(actual_input_gradients)
|
||||
actual = sess.run(actual_input_gradients, {
|
||||
inputs: x,
|
||||
outputs: output_vals,
|
||||
output_gradients: output_gradient_vals
|
||||
})
|
||||
|
||||
# Compare the TensorFlow and XLA results.
|
||||
self.assertAllClose(
|
||||
expected_input_gradient_vals.flatten(),
|
||||
actual.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=1e-6)
|
||||
self.assertShapeEqual(actual, inputs)
|
||||
|
||||
def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize,
|
||||
strides, padding):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
pool_func: Pooling function to be called, e.g., tf.nn.max_pool
|
||||
pool_grad_func: Corresponding pooling gradient function.
|
||||
input_sizes: Input tensor dimensions.
|
||||
ksize: The kernel size dimensions
|
||||
strides: The stride dimensions
|
||||
padding: Padding type.
|
||||
"""
|
||||
for data_format in GetTestConfigs():
|
||||
self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize,
|
||||
strides, padding, data_format)
|
||||
|
||||
def _TestPooling(self, forward_op, backward_op):
|
||||
# VALID padding
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 3, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="VALID")
|
||||
|
||||
# SAME padding
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME")
|
||||
|
||||
# SAME padding, non square window
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 2, 2, 1],
|
||||
ksize=[1, 1, 2, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
padding="SAME")
|
||||
|
||||
# VALID padding, uneven stride
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 1, 2, 1],
|
||||
padding="VALID")
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 1, 1],
|
||||
padding="VALID")
|
||||
|
||||
# SAME padding, size 4 input
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 4, 4, 4],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME")
|
||||
|
||||
# SAME padding, size 8 input
|
||||
self._VerifyValues(
|
||||
forward_op,
|
||||
backward_op,
|
||||
input_sizes=[1, 8, 8, 8],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME")
|
||||
|
||||
def testMaxPool(self):
|
||||
self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad)
|
||||
|
||||
def testAvgPool(self):
|
||||
# Wrapper around AvgPoolGrad that ignores extra arguments needed by
|
||||
# MaxPoolGrad.
|
||||
def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
|
||||
data_format):
|
||||
del outputs # Unused by average-pooling gradients.
|
||||
return gen_nn_ops._avg_pool_grad(
|
||||
inputs.get_shape().as_list(),
|
||||
output_gradients,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format)
|
||||
|
||||
self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)
|
||||
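# Note (illustrative, not part of this change): unlike _max_pool_grad, which
# takes the original inputs and outputs, _avg_pool_grad only needs the
# original input *shape*, so the wrapper above discards `outputs` and passes
# inputs.get_shape().as_list() instead.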
|
||||
# The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
|
||||
# the stride size, so we only run the following tests on MaxPoolGrad.
|
||||
|
||||
def testMaxPoolKernelSmallerThanStrideValid(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
gen_nn_ops._max_pool_grad,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 3, 3, 1],
|
||||
padding="VALID")
|
||||
|
||||
def testMaxPoolKernelSmallerThanStrideSame(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
gen_nn_ops._max_pool_grad,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME")
|
||||
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
gen_nn_ops._max_pool_grad,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
2097
tensorflow/compiler/tests/randomized_tests.cc
Normal file
File diff suppressed because it is too large
125
tensorflow/compiler/tests/reduce_ops_test.py
Normal file
@ -0,0 +1,125 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for reduction operators."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ReduceOpsTest(XLATestCase):
|
||||
|
||||
def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
|
||||
rtol=1e-4, atol=1e-4):
|
||||
"""Tests that the output of 'tf_reduce_fn' matches numpy's output."""
|
||||
|
||||
for test_input in test_inputs:
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
a = array_ops.placeholder(dtype)
|
||||
index = array_ops.placeholder(dtypes.int32)
|
||||
out = tf_reduce_fn(a, index)
|
||||
result = sess.run(out, {a: test_input, index: [0]})
|
||||
self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
result = sess.run(out, {a: test_input, index: [1]})
|
||||
self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
result = sess.run(out, {a: test_input, index: [-1]})
|
||||
self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [-33]})
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [2]})
|
||||
|
||||
FLOAT_DATA = [
|
||||
np.zeros(shape=(2, 0)),
|
||||
np.zeros(shape=(0, 30)),
|
||||
np.arange(1, 7).reshape(2, 3),
|
||||
np.arange(-10, -4).reshape(2, 3),
|
||||
np.arange(-4, 2).reshape(2, 3),
|
||||
]
|
||||
NONEMPTY_FLOAT_DATA = [
|
||||
np.arange(1, 7).reshape(2, 3),
|
||||
np.arange(-10, -4).reshape(2, 3),
|
||||
np.arange(-4, 2).reshape(2, 3),
|
||||
]
|
||||
BOOL_DATA = [
|
||||
np.array([], dtype=np.bool).reshape(2, 0),
|
||||
np.array([], dtype=np.bool).reshape(0, 3),
|
||||
np.array([[False, True, False], [True, True, False]]),
|
||||
]
|
||||
|
||||
def testReduceSum(self):
|
||||
self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
|
||||
self.FLOAT_DATA)
|
||||
|
||||
def testReduceProd(self):
|
||||
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
|
||||
self.FLOAT_DATA)
|
||||
|
||||
def testReduceMin(self):
|
||||
|
||||
def reference_min(inp, axis):
|
||||
"""Wrapper around np.amin that returns +infinity for an empty input."""
|
||||
if inp.shape[axis] == 0:
|
||||
return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
|
||||
return np.amin(inp, axis)
|
||||
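# Illustrative example (not part of this change): for the empty test input
# np.zeros(shape=(2, 0)) above, reference_min(inp, axis=1) returns
# np.full((2,), inf), i.e. the identity element of min, which is what
# reduce_min is expected to produce when reducing across zero elements.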
|
||||
self._testReduction(math_ops.reduce_min, reference_min, np.float32,
|
||||
self.FLOAT_DATA)
|
||||
|
||||
def testReduceMax(self):
|
||||
|
||||
def reference_max(inp, axis):
|
||||
"""Wrapper around np.amax that returns -infinity for an empty input."""
|
||||
if inp.shape[axis] == 0:
|
||||
return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
|
||||
return np.amax(inp, axis)
|
||||
|
||||
self._testReduction(math_ops.reduce_max, reference_max, np.float32,
|
||||
self.FLOAT_DATA)
|
||||
|
||||
def testReduceMean(self):
|
||||
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
|
||||
# reducing across zero inputs.
|
||||
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
|
||||
self.NONEMPTY_FLOAT_DATA)
|
||||
|
||||
def testReduceAll(self):
|
||||
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
|
||||
|
||||
def testReduceAny(self):
|
||||
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
110
tensorflow/compiler/tests/ternary_ops_test.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for ternary operators."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class TernaryOpsTest(XLATestCase):
|
||||
|
||||
def _testTernary(self, op, a, b, c, expected):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
|
||||
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
|
||||
pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
|
||||
output = op(pa, pb, pc)
|
||||
result = session.run(output, {pa: a, pb: b, pc: c})
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testLinspace(self):
|
||||
self._testTernary(
|
||||
math_ops.linspace,
|
||||
np.float32(1),
|
||||
np.float32(2),
|
||||
np.int32(1),
|
||||
expected=np.array([1], dtype=np.float32))
|
||||
self._testTernary(
|
||||
math_ops.linspace,
|
||||
np.float32(1),
|
||||
np.float32(4),
|
||||
np.int32(3),
|
||||
expected=np.array([1, 2.5, 4], dtype=np.float32))
|
||||
|
||||
def testRange(self):
|
||||
self._testTernary(
|
||||
math_ops.range,
|
||||
np.int32(1),
|
||||
np.int32(2),
|
||||
np.int32(1),
|
||||
expected=np.array([1], dtype=np.int32))
|
||||
self._testTernary(
|
||||
math_ops.range,
|
||||
np.int32(1),
|
||||
np.int32(7),
|
||||
np.int32(2),
|
||||
expected=np.array([1, 3, 5], dtype=np.int32))
|
||||
|
||||
def testSelect(self):
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array(0, dtype=np.bool),
|
||||
np.array(2, dtype=np.float32),
|
||||
np.array(7, dtype=np.float32),
|
||||
expected=np.array(7, dtype=np.float32))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array([0, 1, 1, 0], dtype=np.bool),
|
||||
np.array([1, 2, 3, 4], dtype=np.float32),
|
||||
np.array([5, 6, 7, 8], dtype=np.float32),
|
||||
expected=np.array([5, 2, 3, 8], dtype=np.float32))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array([0, 1, 0], dtype=np.bool),
|
||||
np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
|
||||
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
|
||||
|
||||
def testSlice(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
array_ops.slice,
|
||||
np.array([[], [], []], dtype=dtype),
|
||||
np.array([1, 0], dtype=np.int32),
|
||||
np.array([2, 0], dtype=np.int32),
|
||||
expected=np.array([[], []], dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.slice,
|
||||
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
|
||||
np.array([0, 1], dtype=np.int32),
|
||||
np.array([2, 1], dtype=np.int32),
|
||||
expected=np.array([[2], [5]], dtype=dtype))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
346
tensorflow/compiler/tests/unary_ops_test.py
Normal file
@ -0,0 +1,346 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for XLA JIT compiler."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class UnaryOpsTest(XLATestCase):
|
||||
"""Test cases for unary operators."""
|
||||
|
||||
def _testUnary(self, op, inp, expected, equality_test=None):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
pinp = array_ops.placeholder(
|
||||
dtypes.as_dtype(inp.dtype), inp.shape, name="a")
|
||||
output = op(pinp)
|
||||
result = session.run(output, {pinp: inp})
|
||||
if equality_test is None:
|
||||
equality_test = self.assertAllClose
|
||||
equality_test(result, expected, rtol=1e-3)
|
||||
|
||||
def ListsAreClose(self, result, expected, rtol):
|
||||
"""Tests closeness of two lists of floats."""
|
||||
self.assertEqual(len(result), len(expected))
|
||||
for i in range(len(result)):
|
||||
self.assertAllClose(result[i], expected[i], rtol)
|
||||
|
||||
def testAllTypeOps(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testUnary(
|
||||
array_ops.diag,
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
|
||||
dtype=dtype))
|
||||
self._testUnary(
|
||||
array_ops.diag_part,
|
||||
np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
|
||||
np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.identity,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[-1, 1]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.matrix_diag,
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
|
||||
self._testUnary(
|
||||
array_ops.matrix_diag_part,
|
||||
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
|
||||
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.prevent_gradient,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[-1, 1]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.squeeze,
|
||||
np.array([[[[[]]]]], dtype=dtype),
|
||||
expected=np.array([], dtype=dtype))
|
||||
self._testUnary(
|
||||
array_ops.squeeze,
|
||||
np.array([[[1], [2]]], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=dtype))
|
||||
self._testUnary(
|
||||
array_ops.squeeze,
|
||||
np.array([[[1]], [[2]]], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=dtype))
|
||||
self._testUnary(
|
||||
array_ops.squeeze,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
expected=np.array([[1, 2], [3, 4]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.stop_gradient,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[-1, 1]], dtype=dtype))
|
||||
|
||||
def testFloatOps(self):
|
||||
for dtype in self.float_types:
|
||||
self._testUnary(
|
||||
math_ops.ceil,
|
||||
np.array([[-1.7, 1.2]], dtype=dtype),
|
||||
expected=np.array([[-1, 2]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.exp,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[0.36787945, 2.7182817]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.floor,
|
||||
np.array([[-1.7, 1.2]], dtype=dtype),
|
||||
expected=np.array([[-2, 1]], dtype=dtype))
|
||||
|
||||
# Tests for tf.nn ops.
|
||||
self._testUnary(
|
||||
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
|
||||
|
||||
# TODO(b/31644876): enable this test case when fixed.
|
||||
# self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.reciprocal,
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
expected=np.array([[1, 0.5]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.log,
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
expected=np.array([[0, 0.69314718]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.rsqrt,
|
||||
np.array([[4, 16]], dtype=dtype),
|
||||
expected=np.array([[0.5, 0.25]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.sigmoid,
|
||||
np.array(
|
||||
[[1, 1, 1, 1],
|
||||
[1, 2, 3, 4]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[0.7310586, 0.7310586, 0.7310586, 0.7310586],
|
||||
[0.7310586, 0.880797, 0.95257413, 0.98201376]],
|
||||
dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.sqrt,
|
||||
np.array([[4, 9]], dtype=dtype),
|
||||
expected=np.array([[2, 3]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.tanh,
|
||||
np.array(
|
||||
[[1, 1, 1, 1],
|
||||
[1, 2, 3, 4]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[0.76159418, 0.76159418, 0.76159418, 0.76159418],
|
||||
[0.76159418, 0.96402758, 0.99505478, 0.99932933]],
|
||||
dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.log_softmax,
|
||||
np.array(
|
||||
[[1, 1, 1, 1],
|
||||
[1, 2, 3, 4]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
|
||||
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
|
||||
dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.relu,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[0, 1]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.relu6,
|
||||
np.array([[-0.05, 6.05, 5]], dtype=dtype),
|
||||
expected=np.array([[0, 6, 5]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.softmax,
|
||||
np.array(
|
||||
[[1, 1, 1, 1],
|
||||
[1, 2, 3, 4]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[0.25, 0.25, 0.25, 0.25],
|
||||
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
|
||||
dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
nn_ops.softplus,
|
||||
np.array([[-2, 0, 8]], dtype=dtype),
|
||||
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
||||
|
||||
def testNumericOps(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testUnary(
|
||||
math_ops.abs,
|
||||
np.array([[2, -1]], dtype=dtype),
|
||||
expected=np.array([[2, 1]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.neg,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[1, -1]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
math_ops.square,
|
||||
np.array([[-2, 3]], dtype=dtype),
|
||||
expected=np.array([[4, 9]], dtype=dtype))
|
||||
|
||||
self._testUnary(
|
||||
array_ops.zeros_like,
|
||||
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
||||
|
||||
def testLogicalOps(self):
|
||||
self._testUnary(
|
||||
math_ops.logical_not,
|
||||
np.array([[True, False], [False, True]], dtype=np.bool),
|
||||
expected=np.array([[False, True], [True, False]], dtype=np.bool))
|
||||
|
||||
def testBiasAddGrad(self):
|
||||
self._testUnary(
|
||||
gen_nn_ops.bias_add_grad,
|
||||
np.array([[1., 2.], [3., 4.]], dtype=np.float32),
|
||||
expected=np.array([4., 6.], dtype=np.float32))
|
||||
|
||||
self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
|
||||
np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
|
||||
dtype=np.float32),
|
||||
expected=np.array([10., 26.], dtype=np.float32))
|
||||
|
||||
def testCast(self):
|
||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||
types = [dtypes.bool, dtypes.int32, dtypes.float32]
|
||||
for shape in shapes:
|
||||
for src_type in types:
|
||||
for dst_type in types:
|
||||
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
||||
src = src.reshape(shape)
|
||||
|
||||
dst = src.astype(dst_type.as_numpy_dtype)
|
||||
self._testUnary(
|
||||
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
||||
src,
|
||||
expected=dst)
|
||||
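# Note (illustrative, not part of this change): binding dst_type as a default
# argument freezes the value at lambda-definition time instead of closing
# over the loop variable, the usual Python idiom for capturing a loop value
# in a closure.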
|
||||
def testInvertPermutation(self):
|
||||
self._testUnary(
|
||||
array_ops.invert_permutation,
|
||||
np.array([1, 2, 0], np.int32),
|
||||
expected=np.array([2, 0, 1], dtype=np.int32))
|
||||
|
||||
def testRank(self):
|
||||
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
|
||||
for dtype in self.numeric_types:
|
||||
self._testUnary(rank_op, dtype(7), expected=np.int32(0))
|
||||
self._testUnary(
|
||||
rank_op, np.array(
|
||||
[[], []], dtype=dtype), expected=np.int32(2))
|
||||
self._testUnary(
|
||||
rank_op, np.array(
|
||||
[-1, 1], dtype=dtype), expected=np.int32(1))
|
||||
self._testUnary(
|
||||
rank_op, np.array(
|
||||
[[-1, 1]], dtype=dtype), expected=np.int32(2))
|
||||
self._testUnary(
|
||||
rank_op,
|
||||
np.array([[-1], [1], [4]], dtype=dtype),
|
||||
expected=np.int32(2))
|
||||
|
||||
def testShape(self):
|
||||
shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
|
||||
for dtype in self.numeric_types:
|
||||
self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32))
|
||||
self._testUnary(
|
||||
shape_op,
|
||||
np.array([[], []], dtype=dtype),
|
||||
expected=np.array([2, 0], dtype=np.int32))
|
||||
self._testUnary(
|
||||
shape_op,
|
||||
np.array([-1, 1], dtype=dtype),
|
||||
expected=np.array([2], dtype=np.int32))
|
||||
self._testUnary(
|
||||
shape_op,
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=np.int32))
|
||||
self._testUnary(
|
||||
shape_op,
|
||||
np.array([[-1], [1], [4]], dtype=dtype),
|
||||
expected=np.array([3, 1], dtype=np.int32))
|
||||
|
||||
def testSize(self):
|
||||
size_op = lambda x: array_ops.size_internal(x, optimize=False)
|
||||
for dtype in self.numeric_types:
|
||||
self._testUnary(size_op, dtype(7), expected=np.int32(1))
|
||||
self._testUnary(
|
||||
size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
|
||||
self._testUnary(
|
||||
size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
|
||||
self._testUnary(
|
||||
size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
|
||||
self._testUnary(
|
||||
size_op,
|
||||
np.array([[-1], [1], [4]], dtype=dtype),
|
||||
expected=np.int32(3))
|
||||
|
||||
def testUnpack(self):
|
||||
self._testUnary(
|
||||
array_ops.unpack,
|
||||
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
|
||||
expected=[
|
||||
np.array([1., 2.], dtype=np.float32),
|
||||
np.array([3., 4.], dtype=np.float32),
|
||||
np.array([5., 6.], dtype=np.float32),
|
||||
],
|
||||
equality_test=self.ListsAreClose)
|
||||
|
||||
self._testUnary(lambda x: array_ops.unstack(x, axis=1),
|
||||
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
|
||||
expected=[
|
||||
np.array([1., 3., 5.], dtype=np.float32),
|
||||
np.array([2., 4., 6.], dtype=np.float32),
|
||||
],
|
||||
equality_test=self.ListsAreClose)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
81
tensorflow/compiler/tests/xla_device_test.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for XLA devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaDeviceTest(test.TestCase):
|
||||
|
||||
def testCopies(self):
|
||||
"""Tests that copies between GPU and XLA devices work."""
|
||||
if not test.is_gpu_available():
|
||||
return
|
||||
|
||||
with session_lib.Session() as sess:
|
||||
x = array_ops.placeholder(dtypes.float32, [2])
|
||||
with ops.device("GPU"):
|
||||
y = x * 2
|
||||
with ops.device("device:XLA_CPU:0"):
|
||||
z = y * y
|
||||
with ops.device("GPU"):
|
||||
w = y + z
|
||||
result = sess.run(w, {x: [1.5, 0.5]})
|
||||
self.assertAllClose(result, [12., 2.], rtol=1e-3)
|
||||
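# Worked check (illustrative): for x = [1.5, 0.5], y = 2*x = [3, 1] on the
# GPU, z = y*y = [9, 1] on the XLA device, and w = y + z = [12, 2].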
|
||||
def testLoops(self):
|
||||
"""Tests that loops work on XLA devices."""
|
||||
|
||||
with session_lib.Session() as session:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
with ops.device("device:XLA_CPU:0"):
|
||||
c = lambda i, _: math_ops.less(i, 5)
|
||||
b = lambda i, x: (i + 1, x * 2.0 + 1.0)
|
||||
_, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
|
||||
|
||||
result = session.run(y, {x: np.float32(2)})
|
||||
self.assertAllClose(result, np.float32(95), rtol=1e-3)
|
||||
|
||||
def testCond(self):
|
||||
"""Tests that tf.cond works on XLA devices."""
|
||||
|
||||
with session_lib.Session() as session:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = array_ops.placeholder(dtypes.float32)
|
||||
c = array_ops.placeholder(dtypes.bool)
|
||||
with ops.device("device:XLA_CPU:0"):
|
||||
z = x + 1.0
|
||||
w = control_flow_ops.cond(c, lambda: z, lambda: y)
|
||||
t = math_ops.add(z, w)
|
||||
|
||||
result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True})
|
||||
self.assertAllClose(result, np.float32(6), rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
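A minimal sketch (assumed, not part of the test file above) of the arithmetic behind testLoops' expected value: the while_loop runs five iterations of x <- 2*x + 1 starting from 2, giving 5, 11, 23, 47 and finally 95.

    x = 2.0
    for _ in range(5):
      x = x * 2.0 + 1.0
    print(x)  # 95.0, matching the assertAllClose in testLoops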
148
tensorflow/compiler/tests/xla_test.py
Normal file
@ -0,0 +1,148 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of XLA test case."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import re

from tensorflow.contrib.compiler import jit
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging

FLAGS = flags.FLAGS

flags.DEFINE_string('test_device', None,
                    'Tensorflow device on which to place operators under test')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
                    'Path to a file with a list of tests that should not run.')


class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases."""

  def __init__(self, method_name='runTest'):
    super(XLATestCase, self).__init__(method_name)
    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')
    self.all_tf_types = [
        dtypes.DType(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    ]
    self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
    self.int_types = [
        dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
    ]
    self.float_types = [
        dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
    ]
    self.numeric_types = self.int_types + self.float_types

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable
    self.disabled_regex = None
    if FLAGS.disabled_manifest is not None:
      comments_re = re.compile('#.*$')
      manifest_file = open(FLAGS.disabled_manifest, 'r')
      lines = manifest_file.read().splitlines()
      lines = [comments_re.sub('', l).strip() for l in lines]
      self.disabled_regex = re.compile('|'.join(lines))
      manifest_file.close()

  def setUp(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      self.skipTest('{} is disabled by manifest.'.format(name))
      return
    logging.info('Start test case: %s', name)

  def tearDown(self):
    logging.info('End test case: %s', self._testMethodName)

  @contextlib.contextmanager
  def test_session(self):
    """Custom implementation of test_session() for XLA tests.

    We override the standard Tensorflow test_session() since it is too
    specific to CPU and GPU tests. In particular, we want to disable soft
    placement and explicitly assign ops to devices under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    with session.Session(graph=graph) as sess, graph.as_default():
      yield sess

  @contextlib.contextmanager
  def test_scope(self):
    """Test scope that runs tests on a Tensorflow/XLA device.

    Uses a compilation_scope() to mark operators to compile.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield


def Benchmark(tf_bench, builder_fn, use_xla_jit, device):
  """Build a graph and run benchmarks against it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A function that builds a graph when invoked, and returns
      (name, fetches), where name is the name of the test, and fetches
      is a list of tensors to fetch as output.
    use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF.
    device: The tensorflow device to run on, e.g. "cpu", "gpu".
  """

  with ops.Graph().as_default():
    name = None
    targets = []
    with ops.device(device):
      fetches = []
      jit_scope = jit.experimental_jit_scope
      with jit_scope(compile_ops=use_xla_jit):
        name, fetches = builder_fn()

      # We only want to benchmark the operations themselves, and not the data
      # transfer of the result(s). Non-compiled identity ops ensure XLA
      # doesn't know we're dropping the results, otherwise it might compile
      # away the entire computation.
      for fetch in fetches:
        targets.append(array_ops.identity(fetch).op)

    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      xla = 'xla_' if use_xla_jit else ''
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla, device))
Some files were not shown because too many files have changed in this diff.