commit b0d6bf3425

WORKSPACE
@@ -80,3 +80,13 @@ new_http_archive(
        "http://download.tensorflow.org/models/stylize_v1.zip",
    ],
)

new_http_archive(
    name = "speech_commands",
    build_file = "models.BUILD",
    sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
        "http://download.tensorflow.org/models/speech_commands_v0.01.zip",
    ],
)
@@ -286,6 +286,7 @@ filegroup(
        "//tensorflow/contrib/data/python/util:all_files",
        "//tensorflow/contrib/decision_trees/proto:all_files",
        "//tensorflow/contrib/distributions:all_files",
        "//tensorflow/contrib/eager/python:all_files",
        "//tensorflow/contrib/factorization:all_files",
        "//tensorflow/contrib/factorization/kernels:all_files",
        "//tensorflow/contrib/ffmpeg:all_files",
@@ -1931,6 +1931,8 @@ bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
    TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
                                            src.second, dst_inputs[i]);
  }
  opts.get()->opts.skip_mapped_nodes = true;

  // We use the pivot node to control constants in `src_graph`
  TF_Operation* pivot = dst_inputs[0].oper;
  TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
@@ -64,19 +64,14 @@ struct TFE_Context {
  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;

  std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                     tensorflow::Fprint128Hasher>
      kernel_cache;

  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
    for (int i = 0; i < session->devices.size(); ++i) {
      if (session->devices[i] == d) {
        return func_libs[i].get();
      }
    }
    return nullptr;
    return pflr->GetFLR(d->name());
  }

  const std::vector<tensorflow::Device*>& devices() { return session->devices; }
@@ -132,12 +127,9 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
  }

  TFE_Context* ret = new TFE_Context(session);
  ret->func_libs.resize(ret->devices().size());
  for (int i = 0; i < ret->devices().size(); ++i) {
    ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
        ret->session->device_mgr, opts->options.env, ret->devices()[i],
        TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
  }
  ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
      &ret->func_lib_def, {}));
  ret->rendezvous =
      new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
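The two hunks above replace TFE_Context's hand-maintained vector of per-device FunctionLibraryRuntimes (looked up by scanning session->devices for a pointer match) with a single ProcessFunctionLibraryRuntime that owns one runtime per device and resolves it by device name. A minimal sketch of that ownership and lookup pattern, using hypothetical stand-in types rather than TensorFlow's real classes:

#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for FunctionLibraryRuntime and
// ProcessFunctionLibraryRuntime, illustrating only the ownership pattern.
struct DeviceRuntime {};

class ProcessRuntime {
 public:
  void AddDevice(const std::string& name) {
    runtimes_[name] = std::make_unique<DeviceRuntime>();
  }
  // Returns a borrowed pointer; the process-level object keeps ownership,
  // which is why callers in this diff drop their .get() calls.
  DeviceRuntime* GetFLR(const std::string& device_name) {
    auto it = runtimes_.find(device_name);
    return it == runtimes_.end() ? nullptr : it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<DeviceRuntime>> runtimes_;
};

int main() {
  ProcessRuntime pflr;
  pflr.AddDevice("/cpu:0");
  DeviceRuntime* flr = pflr.GetFLR("/cpu:0");  // lookup by name, no index math
  return flr != nullptr ? 0 : 1;
}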
@@ -174,6 +174,16 @@ TEST_F(CApiWhileLoopTest, BasicLoop) {
  EXPECT_TRUE(outputs_[1].oper != nullptr);
  EXPECT_GE(outputs_[1].index, 0);

  // Check that cond and body inputs are not present
  for (int i = 0; i < params_->ninputs; ++i) {
    string cond_name =
        ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
    string body_name =
        ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
  }

  // Run the graph
  Run({-9, 2});
  ExpectOutputValue(0, 3);
@@ -481,6 +481,42 @@ Status AddNGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
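As a concrete check of the selection logic in MaximumMinimumGradCommon: for y = max(x_1, x_2), the incoming gradient flows to whichever input won the comparison, and the other input gets zero. A standalone sketch over plain arrays (not TensorFlow code):

#include <iostream>
#include <vector>

int main() {
  std::vector<float> x1 = {0.5f, 3.0f}, x2 = {1.0f, 1.0f};
  std::vector<float> grad = {10.0f, 20.0f};  // dL/dy from upstream
  std::vector<float> gx1(2), gx2(2);
  for (int i = 0; i < 2; ++i) {
    bool comparator = x1[i] >= x2[i];      // mirrors GreaterEqual for Maximum
    gx1[i] = comparator ? grad[i] : 0.0f;  // Where3(comparator, grad, zeros)
    gx2[i] = comparator ? 0.0f : grad[i];  // Where3(!comparator, grad, zeros)
  }
  // x1[0] < x2[0], so gradient 10 flows to x2; x1[1] > x2[1], so 20 flows
  // to x1. Prints: 0 20 | 10 0
  std::cout << gx1[0] << " " << gx1[1] << " | " << gx2[0] << " " << gx2[1]
            << "\n";
}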
@@ -925,6 +925,15 @@ class NaryGradTest : public ::testing::Test {
    EXPECT_LT(max_error, 1e-3);
  }

  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};
@@ -993,5 +1002,27 @@ TEST_F(NaryGradTest, SquaredDifference) {
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

}  // namespace
}  // namespace tensorflow
@@ -178,6 +178,9 @@ def tf_library(name, graph, config,
      "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
      "//tensorflow/compiler/aot:runtime",
      "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
      "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
      "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
      "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
@@ -624,15 +624,18 @@ Status EncapsulateSubgraphsPass::Run(
  FunctionLibraryDefinition* const library = options.flib_def;

  OptimizerOptions opts;
  std::unique_ptr<FunctionLibraryRuntime> flr(
      NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
                                TF_GRAPH_DEF_VERSION, library, opts));
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
                                        TF_GRAPH_DEF_VERSION, library, opts));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  auto rewrite_subgraph = [&flr](
      std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
      std::vector<int>* output_permutation, NodeDef* node) {
  auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
                                std::vector<int>* input_permutation,
                                std::vector<int>* output_permutation,
                                NodeDef* node) {
    // Optimize the subgraph.
    OptimizeGraph(flr.get(), subgraph);
    OptimizeGraph(flr, subgraph);

    const int num_args = input_permutation->size();
    std::vector<bool> const_args(num_args);
@@ -176,8 +176,11 @@ Status FindCompilationCandidates(
    const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
    std::unordered_set<Node*>* candidates) {
  OptimizerOptions opts;
  std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
      nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
                                        flib_def, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  for (Node* node : graph.op_nodes()) {
    DeviceType device_type("");
@@ -191,7 +194,7 @@ Status FindCompilationCandidates(
        XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
    DeviceType jit_device_type(registration->compilation_device_name);
    if (!HasXLAKernel(*node, jit_device_type) &&
        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
      VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
              << ": " << node->type_string();
      continue;
@@ -203,7 +206,7 @@ Status FindCompilationCandidates(
      continue;
    }
    if (node->type_string() == "While" &&
        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
      continue;
    }
    candidates->insert(node);
@@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

  def testInference(self):
    x_shape = [2, 2, 6, 2]
    scale_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)

    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # To avoid constant folding
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
      epsilon = 0.001
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run(
          [y, mean,
           variance], {t_val: x_val,
                       scale: scale_val,
                       offset: offset_val})
      self.assertAllClose(y_val, y_ref, atol=1e-3)

  def _testLearning(self, use_gradient_checker):
    x_shape = [2, 2, 6, 2]
    scale_shape = [2]
@@ -260,6 +260,13 @@ class UnaryOpsTest(XLATestCase):
          np.array([[-2, 0, 8]], dtype=dtype),
          expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))

      self._assertOpOutputMatchesExpected(
          math_ops.is_finite,
          np.array(
              [[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype),
          expected=np.array(
              [[True, False, True], [False, True, True]], dtype=np.bool))

  def testNumericOps(self):
    for dtype in self.numeric_types:
      self._assertOpOutputMatchesExpected(
@@ -31,6 +31,7 @@ tf_kernel_library(
        "function_ops.cc",
        "gather_op.cc",
        "identity_op.cc",
        "is_finite_op.cc",
        "l2loss_op.cc",
        "lrn_ops.cc",
        "matmul_op.cc",
@@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
                errors::InvalidArgument("Not supported format"));
    feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
  }
    // TODO(b/62843645): Implement BatchNormInference.
    OP_REQUIRES(
        ctx, is_training_,
        errors::InvalidArgument("Fused batch normalization for inference is "
                                "not supported yet on XLA backend."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
        ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
    if (is_training_) {
      xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
          ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
          feature_index_);

    // In training mode, outputs the normalized value as well as the calculated
    // mean and variance.
    for (int i = 0; i < 3; i++) {
      ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
      // In training mode, outputs the normalized value as well as the
      // calculated mean and variance.
      for (int i = 0; i < 3; i++) {
        ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
      }
      // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
      // space 1 & 2". They are used to pass the per-batch mean and
      // variance to the gradient. Here we maintain the same behavior by setting
      // them to the mean and variance calculated by BatchNormTraining.
      ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
      ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
    } else {
      xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
          ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
          ctx->Input(4), epsilon_, feature_index_);
      ctx->SetOutput(0, output);
      // Directly send input to output as mean and variance in inference mode.
      ctx->SetOutput(1, ctx->Input(3));
      ctx->SetOutput(2, ctx->Input(4));
      ctx->SetOutput(3, ctx->Input(3));
      ctx->SetOutput(4, ctx->Input(4));
    }
    // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
    // space 1 & 2". They are used to pass the per-batch mean and
    // variance to the gradient. Here we maintain the same behavior by setting
    // them to the mean and variance calculated by BatchNormTraining.
    ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
    ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
  }

 private:
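For reference, inference-mode batch normalization reuses the moments captured during training instead of recomputing them per batch: y = scale * (x - mean) / sqrt(var + epsilon) + offset. A scalar sketch of the formula the BatchNormInference call implements (not the XLA lowering itself):

#include <cmath>
#include <cstdio>

// Batch-norm inference for one element, using precomputed moments.
float BatchNormInference(float x, float scale, float offset, float mean,
                         float var, float epsilon) {
  return scale * (x - mean) / std::sqrt(var + epsilon) + offset;
}

int main() {
  // With mean 0, variance 1, and epsilon near 0, this reduces to
  // scale * x + offset.
  std::printf("%f\n", BatchNormInference(2.0f, 1.5f, 0.25f, 0.0f, 1.0f, 1e-3f));
}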
tensorflow/compiler/tf2xla/kernels/is_finite_op.cc (new file, 43 lines)
@@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {
namespace {

class IsFiniteOp : public XlaOpKernel {
 public:
  explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    ctx->SetOutput(0, ctx->builder()->IsFinite(input));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
};

REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);

}  // anonymous namespace
}  // namespace tensorflow
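IsFinite is true exactly when an element is neither infinite nor NaN, matching std::isfinite on the host; the expectations in the unary_ops_test hunk above follow directly. A quick host-side check of those same six values:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float vals[] = {42.0f, inf, -123.0f, nan, 0.0f, -0.0f};
  for (float v : vals) {
    // Prints 1 0 1 0 1 1, matching [[True, False, True], [False, True, True]].
    std::printf("%d ", std::isfinite(v) ? 1 : 0);
  }
  std::printf("\n");
}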
@@ -139,7 +139,7 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
                             out_index);

    auto ip1 = bodyb.Add(i, bodyb.ConstantR0<int32>(1));
    bodyb.Tuple({ip1, data, indices_1d, updated_output});
    bodyb.Tuple({ip1, data, idcs, updated_output});
  }
  auto body_status = bodyb.Build();
  // TF_CHECK_OK(body_status);
@@ -88,15 +88,18 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
  }

  local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
                                                      FunctionDefLibrary{}));
  local_flib_runtime_ = NewFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), device_, options.graph_def_version,
  local_pflr_.reset(new ProcessFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), options.graph_def_version,
      local_flib_def_.get(), OptimizerOptions(),
      nullptr /* custom_kernel_creator */);
  flib_runtime_ = NewFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), device_, options.graph_def_version,
      options.flib_def, OptimizerOptions(),
      nullptr /* custom_kernel_creator */);
      nullptr /* custom_kernel_creator */));
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
      OptimizerOptions(), nullptr /* custom_kernel_creator */));

  local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
  flib_runtime_ = pflr_->GetFLR(device_->name());
}

XlaCompiler::~XlaCompiler() = default;
@@ -137,8 +140,8 @@ Status XlaCompiler::CompileFunction(
  }

  const FunctionBody* fbody;
  if (!GetFunctionBody(function, local_flib_runtime_.get(), &fbody).ok()) {
    TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_.get(), &fbody));
  if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) {
    TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody));
  }

  TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args));
@@ -159,7 +162,7 @@ Status XlaCompiler::CompileFunction(
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                     /*device=*/nullptr, &graph, /*shape_map=*/nullptr);

  VLOG(1) << "====================================================";
@@ -464,7 +467,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
  context->set_args(std::move(context_args));

  TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
                                  flib_runtime_.get(), NextStepId()));
                                  flib_runtime_, NextStepId()));

  int num_nonconst_outputs;
  int num_computation_outputs;
@@ -276,7 +276,7 @@ class XlaCompiler {
  xla::Client* client() const { return options_.client; }
  XlaCompilationDevice* device() const { return device_; }
  const DeviceMgr* device_mgr() const { return &device_mgr_; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
@@ -303,9 +303,11 @@ class XlaCompiler {
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<FunctionLibraryRuntime> local_flib_runtime_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
@@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  // TODO(b/62843645): Implement BatchNormInference.
  NoteError(Unimplemented("BatchNormInference is not implemented yet."));
  return ComputationDataHandle();
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }
  BatchNormInferenceRequest request;
  *request.mutable_operand() = operand;
  *request.mutable_scale() = scale;
  *request.mutable_offset() = offset;
  *request.mutable_mean() = mean;
  *request.mutable_variance() = variance;
  request.set_epsilon(epsilon);
  request.set_feature_index(feature_index);

  OpRequest op_request;
  *op_request.mutable_batch_norm_inference_request() = request;
  *op_request.mutable_computation() = computation_.handle();
  AddOpMetadata(&op_request);

  OpResponse response;

  VLOG(2) << "making BatchNormInference request";

  Status s = client_->stub()->Op(&op_request, &response);
  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::BatchNormGrad(
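Client-side builder methods like this one all follow the same shape: bail out if the builder is already in an error state, fill an op-specific request proto, wrap it in an OpRequest tagged with the computation handle, issue the stub call, and parse the returned handle. A schematic sketch of that round trip with hypothetical types (not XLA's actual protos or service API):

#include <cstdio>

// Hypothetical stand-ins; none of these are XLA's real types.
struct Handle { long id = 0; };
struct Request { Handle operand; float epsilon = 0.0f; };
struct Response { Handle result; bool ok = false; };

struct Service {
  // A real service would append an op to the computation; here we fake it.
  Response Op(const Request& request) {
    return Response{Handle{request.operand.id + 1}, true};
  }
};

class Builder {
 public:
  explicit Builder(Service* service) : service_(service) {}
  Handle SomeOp(const Handle& operand, float epsilon) {
    if (!ok_) return Handle{};          // mirrors the first_error_ check
    Request request{operand, epsilon};  // fill the op-specific request
    Response response = service_->Op(request);
    ok_ = response.ok;                  // record status, like ParseOpResponse
    return response.result;             // hand back the new op's handle
  }

 private:
  Service* service_;
  bool ok_ = true;
};

int main() {
  Service service;
  Builder builder(&service);
  std::printf("%ld\n", builder.SomeOp(Handle{7}, 1e-3f).id);  // prints 8
}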
@@ -48,38 +48,56 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,

}  // namespace

StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
  if (ShapeUtil::IsTuple(shape)) {
    std::vector<std::unique_ptr<Literal>> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
                          MakeFakeLiteral(element_shape));
      elements.push_back(std::move(element));
    }
    return Literal::MakeTupleOwned(std::move(elements));
  }
  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
  std::minstd_rand0 engine;
  switch (shape.element_type()) {
    case F32: {
      std::uniform_real_distribution<float> generator(0.0f, 1.0f);
      TF_CHECK_OK(literal->Populate<float>(
          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
            return generator(engine);
          }));
      break;
    }
    case S32: {
      std::uniform_int_distribution<int32> generator(
          std::numeric_limits<int32>::lowest(),
          std::numeric_limits<int32>::max());
      TF_CHECK_OK(literal->Populate<int32>(
          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
            return generator(engine);
          }));
      break;
    }
    default:
      return Unimplemented("Unsupported type for fake literal generation: %s",
                           ShapeUtil::HumanString(shape).c_str());
  }
  return std::move(literal);
}

std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
                                              Client* client) {
  if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
    std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
    std::minstd_rand0 engine;
    switch (shape.element_type()) {
      case F32: {
        std::uniform_real_distribution<float> generator(0.0f, 1.0f);
        TF_CHECK_OK(literal->Populate<float>(
            [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
              return generator(engine);
            }));
        break;
      }
      case S32: {
        std::uniform_int_distribution<int32> generator(
            std::numeric_limits<int32>::lowest(),
            std::numeric_limits<int32>::max());
        TF_CHECK_OK(literal->Populate<int32>(
            [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
              return generator(engine);
            }));
        break;
      }
      default:
        LOG(WARNING)
            << "Unsupported type for host-side fake data generation: "
            << ShapeUtil::HumanString(shape)
            << "; falling back to making small amount of fake data via device.";
        return MakeFakeDataViaDeviceOrDie(shape, client);
    StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
    if (!literal_status.ok()) {
      // If we got an Unimplemented error, fall back to making the fake data via
      // an on-device computation.
      CHECK_EQ(literal_status.status().code(),
               tensorflow::error::UNIMPLEMENTED);
      return MakeFakeDataViaDeviceOrDie(shape, client);
    }
    return client->TransferToServer(*literal).ValueOrDie();
    return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
  }

  // If the data is large, generate it on-device.
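The per-type branches in MakeFakeLiteral are thin wrappers over <random>: a default-seeded std::minstd_rand0 engine feeding a distribution matched to the element type, which also makes the fake data deterministic across runs. The same recipe in isolation:

#include <cstdio>
#include <limits>
#include <random>

int main() {
  // A default-constructed minstd_rand0 has a fixed seed, so the generated
  // "fake" values are reproducible, as in MakeFakeLiteral.
  std::minstd_rand0 engine;
  std::uniform_real_distribution<float> f32_gen(0.0f, 1.0f);
  std::uniform_int_distribution<int> s32_gen(
      std::numeric_limits<int>::lowest(), std::numeric_limits<int>::max());
  for (int i = 0; i < 3; ++i) {
    std::printf("f32=%f s32=%d\n", f32_gen(engine), s32_gen(engine));
  }
}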
@@ -26,6 +26,10 @@ limitations under the License.

namespace xla {

// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);

// Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring
// data from the host to the device.
@@ -1101,7 +1101,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
    VLOG(4) << "  old user: " << user->ToString();
    CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
             reshape_or_broadcast);
    std::vector<HloInstruction*> new_user_operands = user->operands();
    auto new_user_operands = user->operands();
    new_user_operands[reshape_or_broadcast_operand_index] = operand;
    auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
        ShapeUtil::MakeShapeWithLayout(
@@ -1505,9 +1505,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
  // We cannot insert bitcasts if the layouts will not be compatible.
  // TODO(b/33178038): Consider inserting a transpose if a bitcast would be
  // invalid.
  if (!valid_bitcast_callback_(input_shape, lhs->shape()) ||
      !valid_bitcast_callback_(new_filter_shape, rhs->shape()) ||
      !valid_bitcast_callback_(convolution_shape, dot_output_shape)) {
  if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
      !valid_bitcast_callback_(filter_shape, new_filter_shape) ||
      !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
    return Status::OK();
  }
@@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {

  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_grad_op, bool use_fusion);
                  bool rewrite_inference_op, bool rewrite_grad_op,
                  bool use_fusion);

  // Returns whether any batch norm ops were rewritten.
  const bool changed() const { return changed_; }
@@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 private:
  explicit BatchNormRewriterVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
        use_fusion_(use_fusion) {}

@@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
  bool use_fusion_;

@@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {

bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op, bool use_fusion) {
  BatchNormRewriterVisitor visitor(computation,
                                   /*rewrite_training_op=*/rewrite_training_op,
                                   /*rewrite_grad_op=*/rewrite_grad_op,
                                   /*use_fusion=*/use_fusion);
  BatchNormRewriterVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op,
      /*use_fusion=*/use_fusion);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}
@@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
  return Status::OK();
}

Status BatchNormRewriterVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half = computation_->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = computation_->AddInstruction(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormRewriterVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
@@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
  }
  for (auto& comp : computations) {
    if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_grad_op_, use_fusion_)) {
                                      rewrite_inference_op_, rewrite_grad_op_,
                                      use_fusion_)) {
      changed = true;
    }
  }
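One detail of the expansion above: 1/Sqrt[Var[X] + epsilon] is emitted as kPower with a constant exponent of -0.5 rather than a dedicated rsqrt instruction. A scalar sketch of the equivalence:

#include <cmath>
#include <cstdio>

int main() {
  const float var = 4.0f, eps = 1e-3f;
  // The rewriter's kPower(var + eps, -0.5) ...
  float via_power = std::pow(var + eps, -0.5f);
  // ... equals the textbook 1 / sqrt(var + eps).
  float via_sqrt = 1.0f / std::sqrt(var + eps);
  std::printf("%f %f\n", via_power, via_sqrt);  // both ~0.49994
}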
@@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
 public:
  // When use_fusion is set, a multi-output fusion node is created.
  BatchNormRewriter(bool rewrite_training_op = false,
                    bool rewrite_inference_op = false,
                    bool rewrite_grad_op = false, bool use_fusion = true)
      : rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
        use_fusion_(use_fusion) {}
  ~BatchNormRewriter() = default;
@@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {

 private:
  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
  bool use_fusion_;
};
@@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
  BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
@@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
  BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
@@ -1156,8 +1156,7 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets(
    // predecessor set, and must be unambiguous.
    const PointsToSet& init_points_to =
        points_to_analysis.GetPointsToSet(instruction->operand(0));
    const std::vector<const LogicalBuffer*>& init_buffers =
        init_points_to.element(buffer->index());
    const auto& init_buffers = init_points_to.element(buffer->index());
    CHECK_EQ(init_buffers.size(), 1);
    CHECK_GT(predecessor_set.count(init_buffers[0]), 0);
    predecessor_while_buffers.push_back(init_buffers[0]);
@@ -1220,8 +1219,8 @@ const LogicalBuffer* AddBufferToColocatedSet(
    std::vector<const LogicalBuffer*>* colocated_set) {
  // CopyInsertion ensures root points-to set is unambiguous and distinct.
  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
  CHECK(!points_to.IsAmbiguous());
  CHECK(points_to.IsDistinct());
  DCHECK(!points_to.IsAmbiguous());
  DCHECK(points_to.IsDistinct());
  colocated_set->push_back(points_to.element(index)[0]);
  return colocated_set->back();
}
@@ -306,7 +306,7 @@ class BufferAssignment {

  // Returns the set of LogicalBuffers which may be the source of the value at
  // the given index and instruction.
  const std::vector<const LogicalBuffer*>& GetSourceBuffers(
  const PointsToSet::BufferList& GetSourceBuffers(
      const HloInstruction* instruction, const ShapeIndex& index) const {
    return GetPointsToSet(instruction).element(index);
  }
@@ -54,8 +54,7 @@ class BufferLiveness {
  bool MaybeLiveOut(const LogicalBuffer& buffer) const;

  // Returns the complete set of buffers that may be live out of the module.
  const tensorflow::gtl::FlatSet<const LogicalBuffer*>& maybe_live_out_buffers()
      const {
  const PointsToSet::BufferSet& maybe_live_out_buffers() const {
    return maybe_live_out_buffers_;
  }

@@ -106,7 +105,7 @@ class BufferLiveness {
  tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;

  // LogicalBuffers that may be live out of the entry computation.
  tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers_;
  PointsToSet::BufferSet maybe_live_out_buffers_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
};
@@ -37,10 +37,9 @@ class BufferLivenessTest : public HloTestBase {
  const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
                                 const HloInstruction* instruction,
                                 const ShapeIndex& index) {
    const std::vector<const LogicalBuffer*>& pointed_to =
        liveness.points_to_analysis()
            .GetPointsToSet(instruction)
            .element(index);
    const auto& pointed_to = liveness.points_to_analysis()
                                 .GetPointsToSet(instruction)
                                 .element(index);
    CHECK_EQ(1, pointed_to.size());
    CHECK_EQ(instruction, pointed_to[0]->instruction());
    CHECK(index == pointed_to[0]->index());
@@ -72,9 +71,9 @@ class BufferLivenessTest : public HloTestBase {
                                   ShapeUtil::GetSubshape(b->shape(), index)));
    // Lookup PointsTo set for instructions 'a' and 'b'.
    auto& points_to_analysis = liveness.points_to_analysis();
    const std::vector<const LogicalBuffer*>& points_to_a =
    const auto& points_to_a =
        points_to_analysis.GetPointsToSet(a).element(index);
    const std::vector<const LogicalBuffer*>& points_to_b =
    const auto& points_to_b =
        points_to_analysis.GetPointsToSet(b).element(index);
    // Make sure PointsTo sets for 'a' and 'b' are unambiguous.
    EXPECT_EQ(1, points_to_a.size());
@@ -435,8 +434,9 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  // Create param0 Tuple.
  auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
      "param0"));
  // Create independent computations for each tuple element.

@@ -498,8 +498,9 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  // Create param0 Tuple.
  auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
      "param0"));
  // Create dependent computations for each tuple element.
@@ -187,27 +187,25 @@ Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant(

  // Multiple buffers within a parameter/constant may be live out, so collect
  // a set of indices at which to copy first.
  points_to.ForEachElement(
      [this](const ShapeIndex& index,
             const std::vector<const LogicalBuffer*>& buffers) {
        if (IsReadOnlyIndex(index)) {
          return;
        }
        for (const LogicalBuffer* buffer : buffers) {
          // pointee is the HloInstruction producing the buffer which may be
          // liveout.
          HloInstruction* pointee = buffer->instruction();
          if (pointee->opcode() == HloOpcode::kParameter ||
              pointee->opcode() == HloOpcode::kConstant) {
            VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
                    << " index: " << tensorflow::str_util::Join(index, ",")
                    << " may be live out of computation: "
                    << pointee->ToString();
            RecordIndex(index);
            break;
          }
        }
      });
  points_to.ForEachElement([this](const ShapeIndex& index,
                                  const PointsToSet::BufferList& buffers) {
    if (IsReadOnlyIndex(index)) {
      return;
    }
    for (const LogicalBuffer* buffer : buffers) {
      // pointee is the HloInstruction producing the buffer which may be
      // liveout.
      HloInstruction* pointee = buffer->instruction();
      if (pointee->opcode() == HloOpcode::kParameter ||
          pointee->opcode() == HloOpcode::kConstant) {
        VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
                << " index: " << tensorflow::str_util::Join(index, ",")
                << " may be live out of computation: " << pointee->ToString();
        RecordIndex(index);
        break;
      }
    }
  });
  return Status::OK();
}

@@ -230,8 +228,7 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
      buffer_to_source_indices;
  points_to.ForEachElement(
      [this, &buffer_to_source_indices](
          const ShapeIndex& index,
          const std::vector<const LogicalBuffer*>& buffers) {
          const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
        if (buffers.size() > 1) {
          // Record ambiguous points-to set at 'index'.
          if (!indices_to_copy_.element(index)) {
@@ -285,7 +282,7 @@ Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
  }
  const auto& points_to_analysis = liveness.points_to_analysis();
  // Lookup buffers for 'instruction_' and 'other_instruction'.
  const std::vector<const LogicalBuffer*> instruction_buffers =
  const auto instruction_buffers =
      points_to_analysis.GetPointsToSet(instruction_).element(index);
  // If 'instruction_' has ambiguous points-to-set at 'index', it would
  // have been recorded in a previous pass (and we would have returned
@@ -294,7 +291,7 @@ Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
  CHECK_EQ(1, instruction_buffers.size());
  const LogicalBuffer* instruction_buffer = instruction_buffers[0];

  const std::vector<const LogicalBuffer*> other_instruction_buffers =
  const auto other_instruction_buffers =
      points_to_analysis.GetPointsToSet(other_instruction).element(index);
  // Do not insert a copy if both instructions point at the same buffer.
  // This eliminates unnecessary copies of read-only tuple elements.
@@ -451,58 +448,57 @@ StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants(
  FlatSet<const LogicalBuffer*> buffer_set;

  ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
  points_to.ForEachElement(
      [init_hlo, read_only_indices, shared_copies, &buffer_set,
       &copy_overrides](const ShapeIndex& index,
                        const std::vector<const LogicalBuffer*>& buffers) {
        // Look for read-only entry parameters.
        if (!read_only_indices->element(index)) {
          return;
  points_to.ForEachElement([init_hlo, read_only_indices, shared_copies,
                            &buffer_set, &copy_overrides](
      const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
    // Look for read-only entry parameters.
    if (!read_only_indices->element(index)) {
      return;
    }
    for (const LogicalBuffer* buffer : buffers) {
      HloInstruction* pointee = buffer->instruction();
      const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
      if (!is_constant) {
        continue;
      }

      // We have found a constant that is read-only in
      // the while body. These buffers are managed by the caller, and cannot
      // be aliased with HLO buffers. Revert this read-only index,
      // to allow it to be copied.
      *read_only_indices->mutable_element(index) = false;

      // Optimization to allow multiple while loops that share the same
      // read-only entry constants to share a single copy.
      // Only unambiguous and distinct array-shaped buffers are allowed, to
      // reduce code complexity. The shape of the entry parameter must be
      // identical to the shape of the init_hlo at this index, to ensure
      // there were no intervening bitcast or GTE instructions, which are
      // also hard to handle.
      const Shape& pointee_shape = pointee->shape();
      const Shape& init_shape =
          ShapeUtil::GetSubshape(init_hlo->shape(), index);
      if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
          ShapeUtil::Equal(pointee_shape, init_shape) &&
          buffer_set.count(buffer) < 1) {
        HloInstruction** copy = &(*shared_copies)[pointee];
        if (*copy == nullptr) {
          *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
              pointee_shape, HloOpcode::kCopy, pointee));
        }
    for (const LogicalBuffer* buffer : buffers) {
      HloInstruction* pointee = buffer->instruction();
      const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
      if (!is_constant) {
        continue;
      }
        // Add the copy as an override.
        *copy_overrides.mutable_element(index) = *copy;
      }

      // We have found a constant that is read-only in
      // the while body. These buffers are managed by the caller, and cannot
      // be aliased with HLO buffers. Revert this read-only index,
      // to allow it to be copied.
      *read_only_indices->mutable_element(index) = false;
      // Tracks whether this current buffer is distinct.
      buffer_set.insert(buffer);

      // Optimization to allow multiple while loops that share the same
      // read-only entry constants to share a single copy.
      // Only unambiguous and distinct array-shaped buffers are allowed, to
      // reduce code complexity. The shape of the entry parameter must be
      // identical to the shape of the init_hlo at this index, to ensure
      // there were no intervening bitcast or GTE instructions, which are
      // also hard to handle.
      const Shape& pointee_shape = pointee->shape();
      const Shape& init_shape =
          ShapeUtil::GetSubshape(init_hlo->shape(), index);
      if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
          ShapeUtil::Equal(pointee_shape, init_shape) &&
          buffer_set.count(buffer) < 1) {
        HloInstruction** copy = &(*shared_copies)[pointee];
        if (*copy == nullptr) {
          *copy =
              pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
                  pointee_shape, HloOpcode::kCopy, pointee));
        }
        // Add the copy as an override.
        *copy_overrides.mutable_element(index) = *copy;
      }

      // Tracks whether this current buffer is distinct.
      buffer_set.insert(buffer);

      // We've already reverted the read-only index and handled the
      // single-copy optimization above, so there's nothing more to do.
      break;
    }
  });
      // We've already reverted the read-only index and handled the
      // single-copy optimization above, so there's nothing more to do.
      break;
    }
  });
  return copy_overrides;
}
@@ -53,7 +53,7 @@ class CopyInsertionTest : public HloTestBase {
    EXPECT_TRUE(points_to.IsDistinct());
    EXPECT_TRUE(!points_to.IsAmbiguous());

    tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
    auto maybe_live_out_buffers =
        points_to_analysis
            ->GetPointsToSet(module->entry_computation()->root_instruction())
            .CreateFlattenedSet();
@@ -102,6 +102,7 @@ cc_library(
        ":compiler_functor",
        ":cpu_runtime",
        ":cpu_runtime_avx",
        ":cpu_runtime_neon",
        ":cpu_runtime_sse4_1",
        ":disassembler",
        ":runtime_conv2d",
@@ -284,6 +285,7 @@ cc_library(
    deps = [
        ":cpu_runtime",
        ":cpu_runtime_avx",
        ":cpu_runtime_neon",
        ":cpu_runtime_sse4_1",
        ":disassembler",
        ":llvm_ir_runtime",
@@ -309,11 +311,11 @@ cc_library(
    srcs = ["cpu_runtime_sse4_1.cc"],
    hdrs = ["cpu_runtime_sse4_1.h"],
    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
    alwayslink = True,
)

cc_library(
@@ -321,11 +323,24 @@ cc_library(
    srcs = ["cpu_runtime_avx.cc"],
    hdrs = ["cpu_runtime_avx.h"],
    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "cpu_runtime_neon",
    srcs = ["cpu_runtime_neon.cc"],
    hdrs = ["cpu_runtime_neon.h"],
    # runtime_copts() enables -mfpu=neon
    copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
    alwayslink = True,
)

cc_library(
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -54,6 +55,7 @@ CompilerFunctor::AllIntrinsics() {
  VectorIntrinsics intrinsics;
  intrinsics.sse_intrinsics = true;
  intrinsics.avx_intrinsics = true;
  intrinsics.neon_intrinsics = true;
  return intrinsics;
}
@@ -150,20 +152,28 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
    CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
  std::vector<llvm::VecDesc> vector_functions;

  const llvm::VecDesc four_wide_vector_functions[] = {
      {"expf", runtime::kExpV4F32SymbolName, 4},
      {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4},
  const llvm::VecDesc four_wide_vector_functions_neon[] = {
      {"expf", runtime::kExpV4F32NEONSymbolName, 4},
      {"llvm.exp.f32", runtime::kExpV4F32NEONSymbolName, 4},

      {"logf", runtime::kLogV4F32SymbolName, 4},
      {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
      {"logf", runtime::kLogV4F32NEONSymbolName, 4},
      {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4},
  };

  const llvm::VecDesc eight_wide_vector_functions[] = {
      {"expf", runtime::kExpV8F32SymbolName, 8},
      {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
  const llvm::VecDesc four_wide_vector_functions_sse[] = {
      {"expf", runtime::kExpV4F32SSESymbolName, 4},
      {"llvm.exp.f32", runtime::kExpV4F32SSESymbolName, 4},

      {"logf", runtime::kLogV8F32SymbolName, 8},
      {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
      {"logf", runtime::kLogV4F32SSESymbolName, 4},
      {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4},
  };

  const llvm::VecDesc eight_wide_vector_functions_avx[] = {
      {"expf", runtime::kExpV8F32AVXSymbolName, 8},
      {"llvm.exp.f32", runtime::kExpV8F32AVXSymbolName, 8},

      {"logf", runtime::kLogV8F32AVXSymbolName, 8},
      {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8},
  };

  // These functions are generated by XLA as LLVM IR, so they're always
@@ -176,27 +186,45 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
      {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8},
  };

  if (arch == llvm::Triple::x86 || llvm::Triple::x86_64) {
    llvm::SmallVector<llvm::StringRef, 32> features;
    feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
    if (std::find(features.begin(), features.end(), "+sse4.1") !=
            features.end() &&
        available_intrinsics.sse_intrinsics) {
      vector_functions.insert(vector_functions.end(),
                              std::begin(four_wide_vector_functions),
                              std::end(four_wide_vector_functions));
  llvm::SmallVector<llvm::StringRef, 32> features;
  feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
  auto has_feature = [&features](const llvm::StringRef feature) {
    return std::find(features.begin(), features.end(), feature) !=
           features.end();
  };

  switch (arch) {
    case llvm::Triple::x86:
    case llvm::Triple::x86_64: {
      if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) {
        vector_functions.insert(vector_functions.end(),
                                std::begin(four_wide_vector_functions_sse),
                                std::end(four_wide_vector_functions_sse));
      }
      if (has_feature("+avx") && available_intrinsics.avx_intrinsics) {
        vector_functions.insert(vector_functions.end(),
                                std::begin(eight_wide_vector_functions_avx),
                                std::end(eight_wide_vector_functions_avx));
      }
      break;
    }
    if (std::find(features.begin(), features.end(), "+avx") != features.end() &&
        available_intrinsics.avx_intrinsics) {
      vector_functions.insert(vector_functions.end(),
                              std::begin(eight_wide_vector_functions),
                              std::end(eight_wide_vector_functions));
    case llvm::Triple::arm:
    case llvm::Triple::aarch64: {
      if (has_feature("+neon") && available_intrinsics.neon_intrinsics) {
        vector_functions.insert(vector_functions.end(),
                                std::begin(four_wide_vector_functions_neon),
                                std::end(four_wide_vector_functions_neon));
      }
      break;
    }
    default:
      break;
  }

  vector_functions.insert(vector_functions.end(),
                          std::begin(ir_vector_functions),
                          std::end(ir_vector_functions));

  return vector_functions;
}
}  // namespace
@@ -35,6 +35,7 @@ class CompilerFunctor {
  struct VectorIntrinsics {
    bool sse_intrinsics;
    bool avx_intrinsics;
    bool neon_intrinsics;
  };

  // Returns a VectorIntrinsics where all intrinsics are available.
@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
|
||||
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
|
||||
pass.AddPass<BatchNormRewriter>(
|
||||
/*rewrite_training_op=*/true,
|
||||
/*rewrite_inference_op=*/true,
|
||||
/*rewrite_grad_op=*/true,
|
||||
/*use_fusion=*/false);
|
||||
pass.AddPass<AlgebraicSimplifier>(
|
||||
|
@ -316,7 +316,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> CpuExecutable::ExecuteOnStream(
|
||||
[&buffers, &buffers_in_result, &result_buffer, this](
|
||||
const ShapeIndex& index, size_t* buffer_entry) {
|
||||
if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) {
|
||||
const std::vector<const LogicalBuffer*>& sources =
|
||||
const auto& sources =
|
||||
this->GetRootPointsToSet().element(index);
|
||||
// The points to set is unambiguous so the set should be a
|
||||
// singleton.
|
||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#ifdef __AVX__
|
||||
xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32(
|
||||
xla::cpu::runtime::V8F32 x) {
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x) {
|
||||
return Eigen::internal::pexp(x);
|
||||
}
|
||||
|
||||
xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32(
|
||||
xla::cpu::runtime::V8F32 x) {
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x) {
|
||||
return Eigen::internal::plog(x);
|
||||
}
|
||||
#endif // __AVX__
|
||||
@ -35,8 +35,8 @@ namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
const char *const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
|
||||
const char *const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32";
|
||||
const char *const kExpV8F32AVXSymbolName = "__xla_cpu_runtime_ExpV8F32AVX";
|
||||
const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
|
@ -28,11 +28,10 @@ namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
extern const char *const kExpV8F32SymbolName;
|
||||
extern const char *const kLogV8F32SymbolName;
|
||||
extern const char *const kTanhV8F32SymbolName;
|
||||
extern const char *const kExpV8F32AVXSymbolName;
|
||||
extern const char *const kLogV8F32AVXSymbolName;
|
||||
|
||||
typedef float V8F32 __attribute__((__vector_size__(32)));
|
||||
typedef float V8F32AVX __attribute__((__vector_size__(32)));
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
@ -42,12 +41,11 @@ extern "C" {
|
||||
// The following functions are vectorized versions of a selection of libm
|
||||
// library functions.
|
||||
// References to these functions are created by the LLVM vectorizer.
|
||||
xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32(xla::cpu::runtime::V8F32 x)
|
||||
TF_ATTRIBUTE_WEAK;
|
||||
|
||||
xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32(xla::cpu::runtime::V8F32 x)
|
||||
TF_ATTRIBUTE_WEAK;
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK;
|
||||
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK;
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
|
||||
|
46
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc
Normal file
46
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#ifdef __ARM_NEON__
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) {
|
||||
return Eigen::internal::pexp(x);
|
||||
}
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) {
|
||||
Eigen::internal::Packet4f p = x;
|
||||
return Eigen::internal::plog(p);
|
||||
}
|
||||
|
||||
#endif // __ARM_NEON__
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON";
|
||||
const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
62
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h
Normal file
62
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
||||
|
||||
// This header declares functions which may be called by the generated code on
|
||||
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
|
||||
// xla::cpu::SimpleResolver.
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
#ifdef __ARM_NEON__
|
||||
// For the other runtimes (AVX, SSE4.1) we define the vector type directly using
|
||||
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
|
||||
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
|
||||
#include <arm_neon.h>
|
||||
#endif // __ARM_NEON__
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
extern const char *const kExpV4F32NEONSymbolName;
|
||||
extern const char *const kLogV4F32NEONSymbolName;
|
||||
|
||||
#ifdef __ARM_NEON__
|
||||
typedef float32x4_t V4F32NEON;
|
||||
#else
|
||||
// On non-ARM platforms ensure the declaration is present
|
||||
struct V4F32NEON;
|
||||
#endif // __ARM_NEON__
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
extern "C" {
|
||||
|
||||
// The following functions are vectorized versions of a selection of libm
|
||||
// library functions.
|
||||
// References to these functions are created by the LLVM vectorizer.
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK;
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK;
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
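As the header comment notes, the SSE and AVX runtimes define their vector types portably with the GCC/Clang vector extension, while NEON must pull its type from <arm_neon.h>. A small standalone illustration of that extension (not part of the commit):

    // Four packed floats; element-wise arithmetic lowers to a single SIMD op.
    typedef float V4F32 __attribute__((__vector_size__(16)));

    V4F32 AddFour(V4F32 a, V4F32 b) {
      return a + b;  // e.g. one addps instruction on x86-64
    }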
@ -21,14 +21,14 @@ limitations under the License.

#ifdef __SSE4_1__

xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32(
xla::cpu::runtime::V4F32 x) {
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
xla::cpu::runtime::V4F32SSE x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::pexp(p);
}

xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32(
xla::cpu::runtime::V4F32 x) {
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::plog(p);
}
@ -39,8 +39,8 @@ namespace xla {
namespace cpu {
namespace runtime {

const char *const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char *const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32";
const char *const kExpV4F32SSESymbolName = "__xla_cpu_runtime_ExpV4F32SSE";
const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE";

} // namespace runtime
} // namespace cpu
@ -28,11 +28,10 @@ namespace xla {
namespace cpu {
namespace runtime {

extern const char *const kExpV4F32SymbolName;
extern const char *const kLogV4F32SymbolName;
extern const char *const kTanhV4F32SymbolName;
extern const char *const kExpV4F32SSESymbolName;
extern const char *const kLogV4F32SSESymbolName;

typedef float V4F32 __attribute__((__vector_size__(16)));
typedef float V4F32SSE __attribute__((__vector_size__(16)));

} // namespace runtime
} // namespace cpu
@ -43,11 +42,11 @@ extern "C" {
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32(xla::cpu::runtime::V4F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK;

xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32(xla::cpu::runtime::V4F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK;
}

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
@ -566,7 +566,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
[&buffers, &buffers_in_result, &result_buffer, this](
const ShapeIndex& index, size_t* buffer_entry) {
if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) {
const std::vector<const LogicalBuffer*>& sources =
const auto& sources =
this->GetRootPointsToSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton.
@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
@ -91,10 +92,12 @@ class JITSymbolTable {
ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue);
ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation);
ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation);
ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32);
ADD_JIT_SYMBOL_TO_TABLE(LogV8F32);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32);
ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX);
ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON);
ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64);
@ -162,8 +165,9 @@ llvm::StringRef GetHostCpuName() {

CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
CompilerFunctor::VectorIntrinsics intrinsics;
intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32 != nullptr);
intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32 != nullptr);
intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32SSE != nullptr);
intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32AVX != nullptr);
intrinsics.neon_intrinsics = (&__xla_cpu_runtime_ExpV4F32NEON != nullptr);
return intrinsics;
}
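GetAvailableIntrinsics works because the runtime functions are declared TF_ATTRIBUTE_WEAK: if the translation unit guarded by __AVX__, __SSE4_1__, or __ARM_NEON__ was not compiled in, the weak reference resolves to null rather than failing at link time. A standalone sketch of the same probe, with a hypothetical symbol name:

    // Weak-symbol feature probe (GCC/Clang); optional_feature is hypothetical
    // and may or may not be defined by some object file in the link.
    extern "C" void optional_feature() __attribute__((weak));

    bool HasOptionalFeature() {
      return &optional_feature != nullptr;  // null when nothing defined it
    }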

@ -228,6 +228,9 @@ class DfsHloVisitor {

virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;

virtual Status HandleBatchNormInference(
HloInstruction* batchNormInference) = 0;

virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;

// Invoked to inform the visitor that the traversal has completed, and that
@ -245,6 +248,11 @@ class DfsHloVisitor {
VisitState GetVisitState(int id) { return visit_state_.GetState(id); }
VisitState GetVisitState(const HloInstruction& instruction);

// Resize internal state if necessary to hold state for ids <= num.
// This call is purely a performance hint and can be omitted without
// affecting correctness.
void ReserveVisitStates(int num) { visit_state_.Reserve(num); }

void SetVisitState(int id, VisitState state) {
visit_state_.SetState(id, state);
}
@ -298,35 +306,35 @@ class DfsHloVisitor {
private:
class DFSVisitStates {
public:
DFSVisitStates() {
// Avoid frequent resizes of the visited bits array
states_.reserve(512);
DFSVisitStates() {}
void Reserve(uint64 num) {
states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord);
}
VisitState GetState(int id) {
int word_index = id / kStatesPerWord;
VisitState GetState(uint64 id) {
uint64 word_index = id / kStatesPerWord;
if (word_index >= states_.size()) {
return VisitState::kNotVisited;
}
static_assert(static_cast<int>(VisitState::kVisited) < 3,
"VisitState must fit in two bits");
uint64 w = states_[word_index];
int shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state
return static_cast<VisitState>((w >> shift) & 0x3);
}
void SetState(int id, VisitState state) {
int word_index = id / kStatesPerWord;
void SetState(uint64 id, VisitState state) {
uint64 word_index = id / kStatesPerWord;
if (word_index >= states_.size()) {
states_.resize(word_index + 1, 0);
}
uint64* w = &states_[word_index];
int shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint64 mask = 0x3ull << shift;
*w = (*w & ~mask) | (static_cast<uint64>(state) << shift);
DCHECK_EQ(GetState(id), state);
}

private:
static const int kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
// Map from id to two-bit states. We store 32 such states per 64-bit
// value
std::vector<uint64> states_;
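The packing above gives each id a two-bit slot: the state for `id` occupies bits [2*(id % kStatesPerWord), 2*(id % kStatesPerWord) + 1] of states_[id / kStatesPerWord]. A worked example, assuming kStatesPerWord is 32 (two bits per state filling a 64-bit word, per the comment above):

    // id = 37 -> word_index = 37 / 32 = 1, shift = 2 * (37 % 32) = 10,
    // so the state lives in bits 10-11 of states_[1].
    uint64_t w = 0;
    const uint32_t shift = 2 * (37 % 32);
    w = (w & ~(0x3ull << shift)) | (uint64_t{2} << shift);  // store state 2
    const uint32_t state = (w >> shift) & 0x3;              // reads back 2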
@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
return DefaultAction(hlo);
}

Status HandleBatchNormInference(HloInstruction* hlo) override {
return DefaultAction(hlo);
}

Status HandleBatchNormGrad(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
// instead.
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(
@ -232,7 +232,7 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteOnStream(
TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus(
[&referred_by_output, &buffer_allocations, this](
const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& buffers) {
const PointsToSet::BufferList& buffers) {
// The points to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction produced
// the array at this element.
@ -311,7 +311,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> GpuExecutable::ExecuteOnStream(
[&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
const ShapeIndex& index, size_t* buffer_entry) {
if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
const std::vector<const LogicalBuffer*>& sources =
const auto& sources =
this->GetRootPointsToSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction
@ -39,7 +39,7 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
for (const HloInstruction* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
const PointsToSet::BufferList& points_to) {
buffers.insert(buffers.end(), points_to.begin(), points_to.end());
});
}
@ -107,7 +107,7 @@ Status HeapSimulator::RunComputation(
FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;

const HloInstruction* root = computation.root_instruction();
FlatSet<const LogicalBuffer*> output_source_buffers =
auto output_source_buffers =
points_to_analysis.GetPointsToSet(root).CreateFlattenedSet();

std::vector<const LogicalBuffer*> dead_buffers_to_free;
@ -481,7 +481,7 @@ TEST_F(HeapSimulatorTest, WholeModule) {
// Base class for heap algorithm tests.
class HeapAlgorithmTestBase : public ::testing::Test {
protected:
HeapAlgorithmTestBase() {
HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
buffer_a_ = DummyLogicalBuffer();
buffer_b_ = DummyLogicalBuffer();
buffer_c_ = DummyLogicalBuffer();
@ -505,15 +505,16 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const LogicalBuffer* buffer_i_;

private:
// Create a dummy LogicalBuffer to pass to the heap algorithm. Since the
// algorithms only use the buffer as a handle, we don't need to fill in much
// other than the id and color.
// Create a dummy LogicalBuffer to pass to the heap algorithm.
const LogicalBuffer* DummyLogicalBuffer() {
const LogicalBuffer::Id id = buffers_.size();
buffers_.emplace_back(MakeUnique<LogicalBuffer>(nullptr, ShapeIndex{}, id));
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
buffers_.emplace_back(MakeUnique<LogicalBuffer>(const0, ShapeIndex{}, id));
return buffers_.back().get();
}

HloComputation::Builder builder_;
std::vector<std::unique_ptr<LogicalBuffer>> buffers_;
};

@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(
HloInstruction* batchNormInference) {
// TODO(b/62294698): Implement cost analysis for batch-norm-inference.
return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
// TODO(b/62294698): Implement cost analysis for batch-norm-grad.
return Status::OK();
@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation* function_handle) override;
Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kParameter:
return kOrange;
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:
@ -406,6 +406,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
return instruction;
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(scale);
instruction->AppendOperand(offset);
instruction->AppendOperand(mean);
instruction->AppendOperand(variance);
instruction->epsilon_ = epsilon;
instruction->feature_index_ = feature_index;
return instruction;
}
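For reference, the element-wise math a batch-norm-inference op represents, normalizing with precomputed statistics along the feature dimension (a scalar sketch, not XLA's implementation):

    #include <cmath>

    float BatchNormInference(float x, float scale, float offset, float mean,
                             float variance, float epsilon) {
      // Normalize with the given statistics, then apply the affine transform.
      return (x - mean) / std::sqrt(variance + epsilon) * scale + offset;
    }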

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* scale, HloInstruction* mean,
@ -726,7 +743,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// If this is already a multioutput fusion instruction, expand the root
// tuple by 1.
HloInstruction* fused_root = fused_expression_root();
std::vector<HloInstruction*> tuple_elements;
HloInstruction::InstructionVector tuple_elements;
bool newly_created_tuple_instr = false;
if (fused_root->opcode() == HloOpcode::kTuple) {
tuple_elements = fused_root->operands();
@ -735,10 +752,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
newly_created_tuple_instr = true;
}
if (clone->opcode() == HloOpcode::kTuple) {
const auto& tuple_elements_to_fuse = clone->operands();
tuple_elements.insert(tuple_elements.end(),
tuple_elements_to_fuse.begin(),
tuple_elements_to_fuse.end());
for (auto inst : clone->operands()) {
tuple_elements.push_back(inst);
}
} else {
tuple_elements.push_back(clone);
}
@ -1065,6 +1081,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
new_operands[2], epsilon(),
feature_index());

case HloOpcode::kBatchNormInference:
CHECK_EQ(new_operands.size(), 5);
return CreateBatchNormInference(
shape, new_operands[0], new_operands[1], new_operands[2],
new_operands[3], new_operands[4], epsilon(), feature_index());
case HloOpcode::kInfeed:
CHECK_EQ(new_operands.size(), 0);
return CreateInfeed(shape, infeed_config());
@ -1355,6 +1377,7 @@ bool HloInstruction::IdenticalSlowPath(
ShapeUtil::Compatible(shape(), other.shape());

case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
return feature_index() == other.feature_index() &&
epsilon() == other.epsilon();
@ -1940,8 +1963,8 @@ HloInstruction::fused_instructions() const {

HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
shape_(shape),
opcode_(opcode),
shape_(shape),
name_("%" + HloOpcodeString(opcode)) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
@ -1952,6 +1975,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAbs(this, operands_[0]);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
case HloOpcode::kBatchNormInference:
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kSign:
@ -2092,12 +2117,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
HloOpcodeString(opcode_).c_str());
}

using DFSStack =
tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;

// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
inline bool PushDFSChild(
DfsHloVisitor* visitor,
std::vector<std::pair<int, HloInstruction*>>* dfs_stack,
HloInstruction* child) {
inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
@ -2120,13 +2146,15 @@ using InternalCompareFunction =
static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());

// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
// We need to keep track of both the id and the instruction because
// instructions can get deleted while they are on the stack, so we
// can't always use the (potentially dead) instruction object to grab
// its id.
std::vector<std::pair<int, HloInstruction*>> dfs_stack;
DFSStack dfs_stack;
dfs_stack.emplace_back(root->unique_id(), root);

do {
@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -224,6 +225,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, float epsilon, int64 feature_index);

// Creates a batch-norm-inference instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index);

// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
@ -330,7 +337,8 @@ class HloInstruction {
int64 operand_count() const { return operands_.size(); }

// Returns the vector of operands of this instruction.
const std::vector<HloInstruction*>& operands() const { return operands_; }
using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
const InstructionVector& operands() const { return operands_; }

// Returns the index of 'target' in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
@ -980,15 +988,34 @@ class HloInstruction {

int unique_id_; // Unique to this HloInstruction within a HloModule

// Opcode for this instruction.
HloOpcode opcode_;

// Instruction operands.
InstructionVector operands_;

// The set of control predecessors of this instruction.
std::vector<HloInstruction*> control_predecessors_;

// The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical
// members. The set enables fast membership testing and the vector enables
// fast, stable iteration.
std::vector<HloInstruction*> users_;
std::unordered_set<const HloInstruction*> user_set_;

// The set of control successors of this instruction.
std::vector<HloInstruction*> control_successors_;

// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;

// Shape of outfeed request.
Shape outfeed_shape_;

// Result shape of this instruction.
Shape shape_;

// Opcode for this instruction.
HloOpcode opcode_;

// Literal, only present for kConstant.
std::unique_ptr<Literal> literal_;
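The users_/user_set_ pair kept above is a deliberate duplication: the set gives O(1) membership tests while the vector preserves a stable iteration order. A generic, self-contained sketch of the pattern (names are illustrative, not from the diff):

    #include <unordered_set>
    #include <vector>

    class UserList {
     public:
      void Add(const void* u) {
        // Insert into the set first; only extend the vector on a new member,
        // keeping the two containers in lockstep.
        if (user_set_.insert(u).second) users_.push_back(u);
      }
      bool Contains(const void* u) const { return user_set_.count(u) != 0; }
      const std::vector<const void*>& users() const { return users_; }

     private:
      std::vector<const void*> users_;            // fast, stable iteration
      std::unordered_set<const void*> user_set_;  // fast membership testing
    };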

@ -1054,22 +1081,6 @@ class HloInstruction {
// Outfeed configuration information, only present for kOutfeed.
string outfeed_config_;

// Instruction operands.
std::vector<HloInstruction*> operands_;

// The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical
// members. The set enables fast membership testing and the vector enables
// fast, stable iteration.
std::vector<HloInstruction*> users_;
std::unordered_set<const HloInstruction*> user_set_;

// The set of control predecessors of this instruction.
std::vector<HloInstruction*> control_predecessors_;

// The set of control successors of this instruction.
std::vector<HloInstruction*> control_successors_;

// A trace instruction that consumes this instruction.
//
// Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
@ -1098,9 +1109,6 @@ class HloInstruction {
// String identifier for instruction.
string name_;

// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;

// Metadata for debugging.
OpMetadata metadata_;

@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "add";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
case HloOpcode::kBatchNormInference:
return "batch-norm-inference";
case HloOpcode::kBatchNormGrad:
return "batch-norm-grad";
case HloOpcode::kBitcast:
@ -31,6 +31,7 @@ enum class HloOpcode {
kAbs,
kAdd,
kBatchNormTraining,
kBatchNormInference,
kBatchNormGrad,
kBitcast,
kBroadcast,
@ -55,16 +55,6 @@ namespace {

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
// Conservatively, don't rematerialize instructions with control
// dependencies. For one, control dependencies are added to prevent
// interference of aliased buffers (say, in while bodies) and
// rematerialization is ignorant of liveness and may break the intended
// ordering.
if (!instruction->control_predecessors().empty() ||
!instruction->control_successors().empty()) {
return false;
}

// Don't rematerialize instructions with side effects or instructions which
// cannot be cloned safely.
switch (instruction->opcode()) {
@ -503,7 +493,7 @@ MemoryUsageTracker::MemoryUsageTracker(
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list)
: computation_(computation), instruction_list_(instruction_list) {
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
@ -906,6 +896,19 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
continue;
}

// If any of the candidate's control successors has been placed, we need to
// skip this candidate. Otherwise we would violate a control dependency.
bool control_successor_placed =
std::any_of(candidate->control_successors().begin(),
candidate->control_successors().end(),
[&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst);
});

if (control_successor_placed) {
continue;
}

const int64 memory_reduced =
memory_tracker.MemoryReducedIfRematerialized(item);

@ -1047,6 +1050,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(

HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

// Add control dependencies to the new operation.
for (auto successor : best->control_successors()) {
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
}
for (auto predecessor : best->control_predecessors()) {
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
}

Item* remat_item = instruction_list.CreateItem(remat);

// Replace each remaining use of 'best' with the rematerialization.
@ -1082,6 +1094,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
}
}
}
// Insert the rematerialized instruction before any of its successors to
// preserve ordering regarding control dependencies.
for (auto successor : remat->control_successors()) {
Item* successor_item = instruction_list.GetItem(successor);
// Assert to make sure we never remat an operation with a control
// successor already placed.
CHECK(!successor_item->placed);
place_before.push_back(successor_item);
}
instruction_list.InsertBeforeInstructions(remat_item, place_before);

// If the rematerialized instruction is dead then rematerialization is
@ -78,6 +78,7 @@ namespace xla {

// Expensive instructions.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kCall:
case HloOpcode::kConvolution:
@ -628,9 +628,8 @@ Status CheckLayouts(
const PointsToSet& points_to_set =
points_to_analysis->GetPointsToSet(instruction.get());
TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
[&instruction](
ShapeIndex index,
const std::vector<const LogicalBuffer*>& buffers) -> Status {
[&instruction](ShapeIndex index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
const Shape& instruction_subshape =
ShapeUtil::GetSubshape(instruction->shape(), index);
@ -934,7 +933,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
return points_to_set.ForEachElementWithStatus(
[this, &shape_layout, constraints](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& buffers) -> Status {
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
for (const LogicalBuffer* buffer : buffers) {
if (constraints->BufferLayout(*buffer) == nullptr &&
@ -1076,7 +1075,7 @@ StatusOr<Layout> InferArrayLayout(
TF_RET_CHECK(
!points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index));

const std::vector<const LogicalBuffer*>& source_buffers =
const auto& source_buffers =
points_to_analysis.GetPointsToSet(instruction).element(index);
TF_RET_CHECK(!source_buffers.empty());

@ -80,7 +80,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
HloInstruction* instruction, const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis) {
std::vector<std::pair<HloInstruction*, int64>> uses;
const std::vector<const LogicalBuffer*>& points_to =
const PointsToSet::BufferList& points_to =
points_to_analysis.GetPointsToSet(instruction).element(index);
for (const LogicalBuffer* buffer : points_to) {
for (const BufferAlias& alias :
@ -26,6 +26,14 @@ limitations under the License.

namespace xla {

LogicalBuffer::LogicalBuffer(HloInstruction* instruction,
const ShapeIndex& index, Id id)
: instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) {
const auto& s = shape();
is_array_ = ShapeUtil::IsArray(s);
is_tuple_ = ShapeUtil::IsTuple(s);
}

string LogicalBuffer::ToString() const {
return tensorflow::strings::StrCat(instruction_->name(), "[",
tensorflow::str_util::Join(index_, ","),
@ -97,11 +97,7 @@ class LogicalBuffer {
using SizeFunction = std::function<int64(const LogicalBuffer&)>;
using AlignmentFunction = std::function<int64(LogicalBuffer::Color)>;

LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id)
: instruction_(instruction),
index_(index),
id_(id),
color_(kInvalidColor) {}
LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id);

Id id() const { return id_; }

@ -140,14 +136,14 @@ class LogicalBuffer {
bool IsTopLevel() const { return index_.empty(); }

// Whether this buffer contains a tuple.
bool IsTuple() const { return ShapeUtil::IsTuple(shape()); }
bool IsTuple() const { return is_tuple_; }

// Whether this buffer contains an array.
bool IsArray() const { return is_array_; }

// operator< is required for std::set.
bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; }

// Whether this buffer contains an array.
bool IsArray() const { return ShapeUtil::IsArray(shape()); }

string ToString() const;
LogicalBufferProto ToProto(const SizeFunction& size_fn) const;

@ -160,9 +156,11 @@ class LogicalBuffer {

private:
HloInstruction* instruction_;
ShapeIndex index_;
Id id_;
Id id_ : 62;
bool is_array_ : 1;
bool is_tuple_ : 1;
Color color_;
ShapeIndex index_;

// Similar to HLO constructs (HloInstruction, etc), pointers are used for
// comparison to equality, so disable all copying.
@ -129,7 +129,7 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
// metadata.
bool IsElementwiseOfEquivalentReshapesOrTransposes(
const HloInstruction* instruction) {
const std::vector<HloInstruction*>& operands = instruction->operands();
const auto& operands = instruction->operands();
HloInstruction* first_reshape_operand =
FirstNonScalarAndNonTrivialReshapeOperand(instruction);
// If there are no non-trivial reshapes or transposes, then there is nothing
@ -216,7 +216,7 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
<< "\n\tnew elementwise shape: "
<< ShapeUtil::HumanString(new_elementwise_shape);

std::vector<HloInstruction*> operands = instruction->operands();
auto operands = instruction->operands();
for (size_t i = 0; i < operands.size(); ++i) {
// All scalar operands remain as-is, even if they're reshape or transpose,
// to simplify handling wrt special scalar broadcast rules for ops like
@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddBatchNormTrainingInstruction(
arg->batch_norm_training_request());
break;
case OpRequest::kBatchNormInferenceRequest:
handle_status = computation->AddBatchNormInferenceInstruction(
arg->batch_norm_inference_request());
break;
case OpRequest::kBatchNormGradRequest:
handle_status = computation->AddBatchNormGradInstruction(
arg->batch_norm_grad_request());
@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
output_shape_for_mean_and_var});
}

/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
offset_shape, "offset input of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
scale_shape, "scale input of batch norm inference"));

TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
tensorflow::Status::OK());

if (feature_index >= ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
"got feature_index %lld, and rank %lld",
feature_index, ShapeUtil::Rank(operand_shape));
}

if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to "
"be a non-negative number, got %lld",
feature_index);
}

if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-inference to be at least 1; got %lld",
ShapeUtil::Rank(operand_shape));
}

if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(offset_shape));
}

if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(scale_shape));
}

if (!ShapeUtil::ElementIsFloating(operand_shape)) {
return InvalidArgument(
"The operand to batch-norm-inference must have a floating point "
"element type, but the shape is %s",
PrimitiveType_Name(operand_shape.element_type()).c_str());
}

if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of offset factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(offset_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}

if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of scale factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(scale_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}

if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of mean is %s "
"and the shape of operand is %s",
PrimitiveType_Name(mean_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}

if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of variance is %s "
"and the shape of operand is %s",
PrimitiveType_Name(variance_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}

const int64 feature_count = operand_shape.dimensions(feature_index);
Shape output_shape_for_mean_and_var =
ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});

if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
"but the size of offset factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}

if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
"but the size of scale factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}

if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
"but the size of mean is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}

if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
"but the size of variance is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(variance_shape, 0), feature_count);
}

return operand_shape;
}
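As a concrete example of these rules: for an operand of shape f32[10, 20, 30] with feature_index 2, each of offset, scale, mean, and variance must be a rank-1 f32 array of size 30 (the feature count), and the inferred result shape is the operand shape f32[10, 20, 30] itself.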

/* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
@ -71,6 +71,13 @@ class ShapeInference {
const Shape& scale_shape,
int64 feature_index);

// Infers the shape produced by InferBatchNormInference with the given
// operands.
static StatusOr<Shape> InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index);

// Infers the shape produced by InferBatchNormGrad with the given operands.
static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
const Shape& scale_shape,
@ -46,8 +46,7 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
bool PointsToSet::IsAmbiguous() const {
bool ambiguous = false;
ForEachElement(
[&ambiguous](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
[&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
ambiguous |= points_to.size() > 1;
});
return ambiguous;
@ -56,9 +55,8 @@ bool PointsToSet::IsAmbiguous() const {
bool PointsToSet::IsDistinct() const {
bool distinct = true;
std::set<const LogicalBuffer*> all_points_to;
ForEachElement([&distinct, &all_points_to](
const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/,
const BufferList& points_to) {
for (auto& buffer : points_to) {
if (all_points_to.count(buffer) != 0) {
distinct = false;
@ -75,34 +73,31 @@ size_t PointsToSet::size() const {
return CreateFlattenedSet().size();
}

tensorflow::gtl::FlatSet<const LogicalBuffer*> PointsToSet::CreateFlattenedSet()
const {
tensorflow::gtl::FlatSet<const LogicalBuffer*> flat_set;
ForEachElement([&flat_set](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& buffers) {
flat_set.insert(buffers.begin(), buffers.end());
});
PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
BufferSet flat_set;
ForEachElement(
[&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
flat_set.insert(buffers.begin(), buffers.end());
});
return flat_set;
}
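BufferList and BufferSet are the new aliases this change threads through the codebase in place of std::vector<const LogicalBuffer*> and gtl::FlatSet. A sketch of what such aliases plausibly look like; the inline capacity of 1 is an assumption based on points-to sets usually being singletons, and the compactptrset.h include added at the end of this diff supports the set side:

    using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
    using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;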
|
||||
|
||||
bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
|
||||
bool found = false;
|
||||
ForEachElement(
|
||||
[&found, &buffer](
|
||||
const ShapeIndex& /*index*/,
|
||||
const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
|
||||
if (!found &&
|
||||
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
||||
&buffer) != pointed_to_buffers.end()) {
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
|
||||
const BufferList& pointed_to_buffers) {
|
||||
if (!found &&
|
||||
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
||||
&buffer) != pointed_to_buffers.end()) {
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
|
||||
const ShapeIndex& index) const {
|
||||
const std::vector<const LogicalBuffer*>& pointed_to_buffers = element(index);
|
||||
const auto& pointed_to_buffers = element(index);
|
||||
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
||||
&buffer) != pointed_to_buffers.end();
|
||||
}
|
||||
@ -115,14 +110,14 @@ void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
|
||||
mutable_element(index)->push_back(&buffer);
|
||||
}
|
||||
|
||||
const std::set<HloInstruction*>& PointsToSet::tuple_sources(
|
||||
const PointsToSet::SourceSet& PointsToSet::tuple_sources(
|
||||
const ShapeIndex& index) const {
|
||||
return tuple_sources_.element(index);
|
||||
return tree_.element(index).tuple_sources;
|
||||
}
|
||||
|
||||
void PointsToSet::add_tuple_source(const ShapeIndex& index,
|
||||
HloInstruction* tuple) {
|
||||
tuple_sources_.mutable_element(index)->insert(tuple);
|
||||
tree_.mutable_element(index)->tuple_sources.insert(tuple);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
|
||||
@ -177,7 +172,7 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(
|
||||
points_to_set.ForEachElement(
|
||||
[this, &instruction](
|
||||
const ShapeIndex& index,
|
||||
const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
|
||||
const PointsToSet::BufferList& pointed_to_buffers) {
|
||||
for (const LogicalBuffer* buffer : pointed_to_buffers) {
|
||||
PerBuffer(buffer->id())
|
||||
->buffer_aliases.emplace_back(instruction.get(), index);
|
||||
@ -205,7 +200,7 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
|
||||
PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
points_to_set.ForEachMutableElement(
[this, hlo_instruction](const ShapeIndex& index,
std::vector<const LogicalBuffer*>* buffers) {
PointsToSet::BufferList* buffers) {
const LogicalBuffer& buffer = NewLogicalBuffer(hlo_instruction, index);
buffers->push_back(&buffer);
});
@ -232,7 +227,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
// operand to the points-to set for this GetTupleElement instruction.
points_to_set.ForEachMutableElement(
[&, this](const ShapeIndex& target_index,
std::vector<const LogicalBuffer*>* points_to) {
PointsToSet::BufferList* points_to) {
// Construct an index into the operand by prepending element_index to
// the index for the GetTupleElement instruction's points-to set.
ShapeIndex src_index;
@ -289,7 +284,7 @@ Status TuplePointsToAnalysis::HandleTuple(
operand_points_to_set.ForEachElement(
[&points_to_set, &operand_points_to_set, i](
const ShapeIndex& src_index,
const std::vector<const LogicalBuffer*>& points_to) {
const PointsToSet::BufferList& points_to) {
ShapeIndex target_index;
target_index.push_back(i);
for (auto element : src_index) {
@ -324,7 +319,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select,
PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true);
const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
points_to_set.ForEachMutableElement(
[&](const ShapeIndex& index, std::vector<const LogicalBuffer*>* buffers) {
[&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
for (const LogicalBuffer* false_buffer :
false_points_to_set.element(index)) {
points_to_set.AddPointedToBuffer(*false_buffer, index);
@ -361,8 +356,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(

bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
const HloInstruction* instruction, const ShapeIndex& index) const {
const std::vector<const LogicalBuffer*>& buffers =
GetPointsToSet(instruction).element(index);
const auto& buffers = GetPointsToSet(instruction).element(index);
return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
}

@ -398,8 +392,7 @@ const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(

StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
const HloInstruction* instruction, const ShapeIndex& index) const {
const std::vector<const LogicalBuffer*>& buffers =
GetPointsToSet(instruction).element(index);
const auto& buffers = GetPointsToSet(instruction).element(index);
if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
return FailedPrecondition(
"instruction %s does not define buffer at index {%s}",
@ -424,27 +417,26 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
const HloInstruction* instruction,
TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
GetPointsToSet(instruction)
.ForEachElement(
[this, buffers, instruction](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& source_buffers) {
// Add buffers which 'instruction' is the source of.
CHECK(!source_buffers.empty());
if (source_buffers.size() == 1 &&
source_buffers[0]->instruction() == instruction) {
// If this instruction is the source of this buffer the
// indices must match.
DCHECK(source_buffers[0]->index() == index);
buffers->push_back(source_buffers[0]);
} else {
// If the points-to set includes more than one buffer then
// necessarily this instruction did not produce the
// buffer.
for (const LogicalBuffer* source_buffer : source_buffers) {
DCHECK(source_buffer->instruction() != instruction);
}
}
});
.ForEachElement([this, buffers, instruction](
const ShapeIndex& index,
const PointsToSet::BufferList& source_buffers) {
// Add buffers which 'instruction' is the source of.
CHECK(!source_buffers.empty());
if (source_buffers.size() == 1 &&
source_buffers[0]->instruction() == instruction) {
// If this instruction is the source of this buffer the
// indices must match.
DCHECK(source_buffers[0]->index() == index);
buffers->push_back(source_buffers[0]);
} else {
// If the points-to set includes more than one buffer then
// necessarily this instruction did not produce the
// buffer.
for (const LogicalBuffer* source_buffer : source_buffers) {
DCHECK(source_buffer->instruction() != instruction);
}
}
});
return Status::OK();
}

@ -456,7 +448,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
const PointsToSet& src_points_to_set = GetPointsToSet(src);
dst_points_to_set.ForEachMutableElement(
[this, &dst_points_to_set, &src_points_to_set](
const ShapeIndex& index, std::vector<const LogicalBuffer*>* buffers) {
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
*buffers = src_points_to_set.element(index);
for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
dst_points_to_set.add_tuple_source(index, tuple_source);
@ -505,19 +497,18 @@ void TuplePointsToAnalysis::InstructionToString(
tensorflow::strings::StrAppend(output, prefix, " instruction ",
instruction->ToShortString(), ":\n");
const PointsToSet& points_to_set = GetPointsToSet(instruction);
points_to_set.ForEachElement(
[&prefix, &output](const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& points_to) {
tensorflow::strings::StrAppend(
output, prefix, " {", tensorflow::str_util::Join(index, ","),
"}: ",
tensorflow::str_util::Join(
points_to, ", ",
[](string* out, const LogicalBuffer* source) {
out->append(source->ToString());
}),
"\n");
});
points_to_set.ForEachElement([&prefix, &output](
const ShapeIndex& index,
const PointsToSet::BufferList& points_to) {
tensorflow::strings::StrAppend(
output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ",
tensorflow::str_util::Join(
points_to, ", ",
[](string* out, const LogicalBuffer* source) {
out->append(source->ToString());
}),
"\n");
});
}

} // namespace xla

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@ -46,14 +47,12 @@ namespace xla {
// nested tuple). Each node in this tree corresponds to a single buffer in the
// instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
class PointsToSet {
public:
// Construct our ShapeTree with a pointer rather than a reference to a Shape
// because this is very hot code, and copying (and then destroying) all these
// Shapes is slow.
explicit PointsToSet(const Shape* shape)
: ShapeTree<std::vector<const LogicalBuffer*>>(shape),
tuple_sources_(shape) {}
explicit PointsToSet(const Shape* shape) : tree_(shape) {}

// Returns true if any points-to sets for any subshape element is not a
// singleton.
@ -69,7 +68,8 @@ class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {

// Creates a set containing the union of all LogicalBuffers contained in the
// PointsToSet.
tensorflow::gtl::FlatSet<const LogicalBuffer*> CreateFlattenedSet() const;
using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
BufferSet CreateFlattenedSet() const;

// Returns true if the given buffer is in the points-to set at the given
// index.
@ -102,13 +102,49 @@ class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
// tuple_sources() at the index of an array shape (not a tuple) returns the
// empty set. The instructions in the set returned by tuple_sources
// necessarily are either Tuple instructions, constants, or parameters.
const std::set<HloInstruction*>& tuple_sources(const ShapeIndex& index) const;
using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
const SourceSet& tuple_sources(const ShapeIndex& index) const;

// Add a tuple source instruction for the given index.
void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);

using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;

// Return the list of logical buffers for the subshape at index.
const BufferList& element(const ShapeIndex& index) const {
return tree_.element(index).buffers;
}
BufferList* mutable_element(const ShapeIndex& index) {
return &tree_.mutable_element(index)->buffers;
}

// Call fn(index, buflist) for every subshape index.
template <typename Fn>
void ForEachElement(const Fn& fn) const {
tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
fn(index, elem.buffers);
});
}
template <typename Fn>
void ForEachMutableElement(const Fn& fn) {
tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
fn(index, &elem->buffers);
});
}
template <typename Fn>
Status ForEachElementWithStatus(const Fn& fn) const {
return tree_.ForEachElementWithStatus(
[&fn](const ShapeIndex& index, const Elem& elem) {
return fn(index, elem.buffers);
});
}

private:
ShapeTree<std::set<HloInstruction*>> tuple_sources_;
struct Elem {
BufferList buffers;
SourceSet tuple_sources;
};
ShapeTree<Elem> tree_;

// PointsToSet contains references (const LogicalBuffer*) to elements within
// TuplePointsToAnalysis so disable copying.

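To make the reworked PointsToSet API concrete, here is a toy model (illustrative Python only, hypothetical names, not XLA's API) of the points-to data this analysis would record for a two-element tuple instruction %t = tuple(%a, %b): the top-level index holds the buffer the tuple instruction itself defines, each leaf index aliases the operand's buffer, and tuple_sources records which instruction materialized the tuple structure.

# Toy model of a PointsToSet for %t = tuple(%a, %b); names are hypothetical.
points_to = {
    (): ["buffer(%t, {})"],    # BufferList at the top-level index
    (0,): ["buffer(%a, {})"],  # element 0 aliases %a's top-level buffer
    (1,): ["buffer(%b, {})"],  # element 1 aliases %b's top-level buffer
}
tuple_sources = {(): {"%t"}}   # SourceSet: %t materialized the tuple at {}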
@ -35,8 +35,8 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {

using ::testing::UnorderedElementsAreArray;
using ::testing::UnorderedElementsAre;
using ::testing::UnorderedElementsAreArray;

class TuplePointsToAnalysisTest : public HloTestBase {
protected:
@ -62,7 +62,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// index. CHECKs if no buffer is defined at that point.
const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
const ShapeIndex& index) {
const std::vector<const LogicalBuffer*>& pointed_to =
const auto& pointed_to =
points_to_analysis_->GetPointsToSet(instruction).element(index);
CHECK_EQ(1, pointed_to.size());
CHECK_EQ(instruction, pointed_to[0]->instruction());
@ -73,7 +73,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the given
// LogicalBuffers.
void ExpectHasBuffers(
const std::vector<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferList& points_to_set,
tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
@ -82,22 +82,22 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the
// top-level buffers of the given instructions.
void ExpectHasTopLevelBuffers(
const std::vector<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferList& points_to_set,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
std::vector<const LogicalBuffer*> buffers;
PointsToSet::BufferList buffers;
for (auto instruction : instructions) {
buffers.push_back(GetBuffer(instruction, /*index=*/{}));
}
ExpectHasBuffers(points_to_set, buffers);
}

// Overload which takes a std::set instead of a std::vector.
// Overload which takes a set instead of a vector.
void ExpectHasTopLevelBuffers(
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferSet& points_to_set,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
ExpectHasTopLevelBuffers(std::vector<const LogicalBuffer*>(
points_to_set.begin(), points_to_set.end()),
instructions);
ExpectHasTopLevelBuffers(
PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
instructions);
}

// Checks that the buffer defined at the given instruction and index has

@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request) {
tensorflow::mutex_lock lock(mutex_);

TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(batch_norm_inference_request.operand()));

TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
LookUpRequest(batch_norm_inference_request.scale()));

TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
LookUpRequest(batch_norm_inference_request.offset()));

TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
LookUpRequest(batch_norm_inference_request.mean()));

TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
LookUpRequest(batch_norm_inference_request.variance()));

ComputationDataHandle handle = CreateComputationDataHandle();

OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferBatchNormInferenceShape(
operand->output_shape(), scale->output_shape(),
offset->output_shape(), mean->output_shape(),
variance->output_shape(),
batch_norm_inference_request.feature_index()));

*request.mutable_output_shape() = inferred_shape;

*request.mutable_output_handle() = handle;

*request.mutable_request()->mutable_batch_norm_inference_request() =
batch_norm_inference_request;

VLOG(1) << "AddBatchNormInferenceInstruction ("
<< GetVersionedHandleInternal() << "), data handle "
<< handle.handle() << ": "
<< batch_norm_inference_request.ShortDebugString();

return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request) {
tensorflow::mutex_lock lock(mutex_);
@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
break;
}

case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
ConstantVisitor(session_computation,
batch_norm_inference_request.operand(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.offset(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.variance(), visited,
is_constant);
break;
}

case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2119,6 +2185,18 @@ static void ForEachOperand(
break;
}

case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();

apply(batch_norm_inference_request.operand());
apply(batch_norm_inference_request.scale());
apply(batch_norm_inference_request.offset());
apply(batch_norm_inference_request.mean());
apply(batch_norm_inference_request.variance());
break;
}

case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
break;
}

case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
HloInstruction* operand =
lookup_instruction(batch_norm_inference_request.operand());
HloInstruction* scale =
lookup_instruction(batch_norm_inference_request.scale());
HloInstruction* offset =
lookup_instruction(batch_norm_inference_request.offset());
HloInstruction* mean =
lookup_instruction(batch_norm_inference_request.mean());
HloInstruction* variance =
lookup_instruction(batch_norm_inference_request.variance());

hlo_instruction =
add_instruction(HloInstruction::CreateBatchNormInference(
request.output_shape(), operand, scale, offset, mean, variance,
batch_norm_inference_request.epsilon(),
batch_norm_inference_request.feature_index()));
break;
}

case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();

@ -89,6 +89,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
const BatchNormTrainingRequest& batch_norm_training_request);

// Enqueues a batch norm inference instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request);

// Enqueues a batch norm grad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request);

@ -115,25 +115,16 @@ class ShapeTree {
ShapeTree(Shape shape, const T& init_value);
ShapeTree(const Shape* shape, const T& init_value);

ShapeTree(const ShapeTree& other)
: root_(other.root_), shape_storage_(other.shape_storage_) {
// Fix up internal pointer if necessary.
if (shape_storage_) {
CHECK_EQ(other.shape_, &*other.shape_storage_);
shape_ = &*shape_storage_;
} else {
shape_ = other.shape_;
}
}
ShapeTree(const ShapeTree& other) { *this = other; }

ShapeTree& operator=(const ShapeTree& other) {
root_ = other.root_;
shape_storage_ = other.shape_storage_;

// Fix up internal pointer if necessary.
if (shape_storage_) {
CHECK_EQ(other.shape_, &*other.shape_storage_);
shape_ = &*shape_storage_;
if (other.shape_storage_) {
CHECK_EQ(other.shape_, other.shape_storage_.get());
shape_storage_.reset(new Shape(*other.shape_));
shape_ = shape_storage_.get();
} else {
shape_ = other.shape_;
}
@ -259,11 +250,11 @@ class ShapeTree {
Node root_;

// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullopt.
tensorflow::gtl::optional<Shape> shape_storage_;
// here. Otherwise if we don't own our shape, this is nullptr.
std::unique_ptr<Shape> shape_storage_;

// The XLA shape mirrored in this ShapeTree. This is either a pointer into
// shape_storage_ or the Shape pointer passed to our constructor.
// The XLA shape mirrored in this ShapeTree. This is either
// shape_storage_.get() or the Shape pointer passed to our constructor.
const Shape* shape_;
};

@ -401,10 +392,12 @@ void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {

template <typename T>
ShapeTree<T>::ShapeTree(Shape shape)
: root_(), shape_storage_(std::move(shape)), shape_(&*shape_storage_) {
: root_(),
shape_storage_(MakeUnique<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&*shape_storage_);
LayoutUtil::ClearLayout(shape_storage_.get());
InitChildren(*shape_, &root_);
}

@ -416,11 +409,11 @@ ShapeTree<T>::ShapeTree(const Shape* shape) : root_(), shape_(shape) {
template <typename T>
ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
: root_(init_value),
shape_storage_(std::move(shape)),
shape_(&*shape_storage_) {
shape_storage_(MakeUnique<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&*shape_storage_);
LayoutUtil::ClearLayout(shape_storage_.get());
InitChildren(*shape_, init_value, &root_);
}

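The switch from gtl::optional<Shape> to std::unique_ptr<Shape> makes ShapeTree non-trivially copyable, so operator= now deep-copies the owned Shape and re-points shape_ at the copy's own storage. A minimal sketch of that copy-and-fix-up idiom, in Python purely for illustration (hypothetical ToyTree class, not the real template):

import copy

class ToyTree(object):
  """Owns `storage` (maybe) and keeps `shape` referring at the right object."""

  def __init__(self, shape, owns_shape):
    self.storage = copy.deepcopy(shape) if owns_shape else None
    self.shape = self.storage if owns_shape else shape

  def assign_from(self, other):
    # Deep-copy the owned storage so our internal reference never aliases
    # `other`; a shallow copy would leave self.shape pointing into `other`,
    # which is exactly the bug class the patched operator= avoids.
    if other.storage is not None:
      self.storage = copy.deepcopy(other.storage)
      self.shape = self.storage
    else:
      self.storage = None
      self.shape = other.shape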
@ -425,13 +425,42 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
HumanString(program_shape.result()));
}

/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(const string& s) {
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
tensorflow::str_util::RemoveLeadingWhitespace(s);

if (s->Consume("(")) {  // Tuple.
std::vector<Shape> shapes;
bool must_end = false;
while (true) {
if (s->Consume(")")) {
break;
} else if (must_end) {
return InvalidArgument("Expected end of tuple; got: \"%s\"",
s->ToString().c_str());
}
shapes.emplace_back();
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
tensorflow::str_util::RemoveLeadingWhitespace(s);
must_end = !s->Consume(",");
}
return ShapeUtil::MakeTupleShape(shapes);
}

string element_type_string;
string dimensions_string;
string layout_string;
if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?",
&element_type_string, &dimensions_string,
&layout_string)) {
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
// we convert into the RE2-consumable type and then consume the corresponding
// amount from our StringPiece type.
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
if (RE2::Consume(&s_consumable,
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?",
&element_type_string, &dimensions_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
auto comma_list_to_int64s =
[&s](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
@ -439,39 +468,58 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
int64 element;
if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
return InvalidArgument(
"invalid value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s.c_str());
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s->ToString().c_str());
}
results.push_back(element);
}
return results;
};

// Extract the dimensions.
TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
comma_list_to_int64s(dimensions_string));
PrimitiveType primitive_type;
if (element_type_string == "f32") {
primitive_type = F32;
} else if (element_type_string == "s32") {
primitive_type = S32;
} else if (element_type_string == "u32") {
primitive_type = U32;
} else {
LOG(FATAL) << "unhandled element type string: " << element_type_string;

// Extract the primitive element type.
PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID;
for (PrimitiveType i =
static_cast<PrimitiveType>(PRIMITIVE_TYPE_INVALID + 1);
i < TUPLE; i = static_cast<PrimitiveType>(i + 1)) {
if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) ==
element_type_string) {
primitive_type = i;
break;
}
}
if (primitive_type == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}

Shape result;
if (layout_string.empty()) {
result = MakeShape(primitive_type, dimensions);
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_RET_CHECK(dimensions.size() == min2maj.size());
result = MakeShapeWithLayout(primitive_type, dimensions, min2maj);
result =
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
}
TF_DCHECK_OK(ValidateShape(result));
return result;
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
return std::move(result);
}

return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str());
return InvalidArgument("Invalid shape string to parse: \"%s\"",
s->ToString().c_str());
}
} // namespace

/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
tensorflow::StringPiece s) {
return ParseShapeStringInternal(&s);
}

/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,

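The recursive-descent structure above is easy to mirror in a few lines. A simplified Python analogue (a hypothetical helper with a reduced grammar: no layouts, no empty tuples; not the XLA implementation) that likewise consumes from the front of the string and recurses for tuples:

import re

def parse_shape(s, pos=0):
  """Parses e.g. '(f32[1,2], u32[3])' into nested (type, dims) tuples."""
  while pos < len(s) and s[pos].isspace():
    pos += 1
  if s[pos] == "(":  # Tuple: parse comma-separated elements recursively.
    pos += 1
    elements = []
    while True:
      element, pos = parse_shape(s, pos)
      elements.append(element)
      while pos < len(s) and s[pos].isspace():
        pos += 1
      if s[pos] != ",":
        break
      pos += 1
    assert s[pos] == ")", "expected end of tuple"
    return tuple(elements), pos + 1
  m = re.match(r"(\w+)\[([\d,]*)\]", s[pos:])  # Array: type[dims].
  assert m, "invalid shape string"
  dims = [int(d) for d in m.group(2).split(",") if d]
  return (m.group(1), dims), pos + m.end()

# parse_shape("(f32[1,2], u32[3])")[0] == (("f32", [1, 2]), ("u32", [3]))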
@ -125,7 +125,7 @@ class ShapeUtil {

// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
static StatusOr<Shape> ParseShapeString(const string& s);
static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);

// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.

@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) {
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
string shape_string = "(f32[1572864],s8[5120,1024])";
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
Shape expected =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
ShapeUtil::MakeShape(S8, {5120, 1024})});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2]), f32[3])";
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});

@ -785,6 +785,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
&builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f});
auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f});
auto minimum = builder.Pow(lhs, rhs);

ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({});

@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());
const std::vector<int64>& bounds = GetParam().bounds;
Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
input_array.FillRandom(GetParam().random_value_var,
GetParam().random_value_mean);

const int64 feature_index = GetParam().feature_index;
const int64 num_elements_per_feature =
Product(bounds) / bounds[feature_index];
const int64 feature_bound = bounds[feature_index];
std::vector<float> offset(feature_bound, 1);
std::vector<float> scale(feature_bound, 2);

auto input_squared =
ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
std::vector<int64> reduce_dims;
for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
if (i != feature_index) {
reduce_dims.push_back(i);
}
}

auto sum =
ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });

auto sum_squared =
ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });

std::vector<float> mean(feature_bound);

for (int64 i = 0; i < feature_bound; ++i) {
mean[i] = sum[i] / num_elements_per_feature;
}

std::vector<float> mean_square(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
mean_square[i] = mean[i] * mean[i];
}

std::vector<float> square_mean(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
square_mean[i] = sum_squared[i] / num_elements_per_feature;
}

std::vector<float> var(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
var[i] = square_mean[i] - mean_square[i];
}

Array4D<float> mean4D =
*ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
auto offset4D =
*ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);

auto offset_literal = Literal::CreateR1<float>(offset);
auto scale_literal = Literal::CreateR1<float>(scale);
auto mean_literal = Literal::CreateR1<float>(mean);
auto var_literal = Literal::CreateR1<float>(var);
auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

auto input_activations =
builder.Parameter(0, input_literal->shape(), "input");
auto scale_activations =
builder.Parameter(1, scale_literal->shape(), "offset");
auto offset_activations =
builder.Parameter(2, offset_literal->shape(), "scale");
auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
auto variance_activations =
builder.Parameter(4, var_literal->shape(), "variance");

Array4D<float> expected = normalized;

std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();

builder.BatchNormInference(input_activations, scale_activations,
offset_activations, mean_activations,
variance_activations, epsilon, feature_index);

ComputeAndCompareR4<float>(
&builder, expected,
{input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
variance_data.get()},
ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());

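For reference, the per-feature-channel quantities the test computes are the usual inference-mode batch normalization, with the variance recovered from the two reduced sums via the standard identity:

\mu_c = \frac{1}{N}\sum_{i} x_{i,c}, \qquad
\sigma_c^2 = \frac{1}{N}\sum_{i} x_{i,c}^2 - \mu_c^2, \qquad
y_{i,c} = \mathrm{scale}_c \cdot \frac{x_{i,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \mathrm{offset}_c

where N is num_elements_per_feature and the sums run over every element sharing feature channel c.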
@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@ -55,11 +56,16 @@ limitations under the License.

namespace xla {
namespace tools {
namespace {

// Invokes the given computation passing arbitrary data for every (unbound)
// parameter if use_fake_data; otherwise, uses recorded data if available.
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
// otherwise, no infeed is performed.
StatusOr<std::unique_ptr<Literal>> ReplayComputation(
const SessionModule& module, bool use_fake_data, Client* client) {
const SessionModule& module, tensorflow::StringPiece fake_infeed_shape,
bool use_fake_data, Client* client) {
TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));

std::vector<std::unique_ptr<GlobalData>> arguments;
@ -74,6 +80,27 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
}
}

// We only instantiate the thread pool if the user has requested that a
// concurrent infeed occur via the fake_infeed_shape.
tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;

if (!fake_infeed_shape.empty()) {
pool.emplace(tensorflow::Env::Default(), "infeed",
/*num_threads=*/1);
pool->Schedule([fake_infeed_shape, client]() {
StatusOr<Shape> shape_status =
ShapeUtil::ParseShapeString(fake_infeed_shape);
TF_CHECK_OK(shape_status.status());
Shape shape = std::move(shape_status).ValueOrDie();
StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
TF_CHECK_OK(data_status.status());
std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
while (true) {
TF_CHECK_OK(client->TransferToInfeed(*data));
}
});
}

std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) {
@ -82,17 +109,20 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
return client->ExecuteAndTransfer(computation, execute_arguments);
}

void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
int RealMain(tensorflow::gtl::ArraySlice<char*> args,
tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) {
Client* client = ClientLibrary::LocalClientOrDie();
tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
SessionModule module;
TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
StatusOr<std::unique_ptr<Literal>> result_status =
ReplayComputation(module, use_fake_data, client);
ReplayComputation(module, fake_infeed_shape, use_fake_data, client);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
exit_status = EXIT_FAILURE;
continue;
}
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
@ -105,17 +135,22 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
Literal(module.result()).ToString().c_str());
}
}
return exit_status;
}

} // namespace
} // namespace tools
} // namespace xla

int main(int argc, char** argv) {
// Flags
string fake_infeed_shape;
bool use_fake_data = false;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("use_fake_data", &use_fake_data,
"Replay computation using fake data"),
tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape,
"Shape of fake data to construct for (infinite) infeed"),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@ -126,6 +161,5 @@ int main(int argc, char** argv) {

tensorflow::gtl::ArraySlice<char*> args(argv, argc);
args.pop_front(); // Pop off the binary name, argv[0]
xla::tools::RealMain(args, use_fake_data);
return 0;
return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data);
}

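The optional infeed pool above simply parks one thread in an endless transfer loop. The same pattern, sketched in Python purely for illustration (hypothetical helper names; the tool itself is C++):

import threading

def start_fake_infeed(make_fake_literal, transfer_to_infeed):
  """Repeatedly feeds one pre-built fake value from a background thread."""
  data = make_fake_literal()  # Build the fake payload once, up front.
  def loop():
    while True:
      transfer_to_infeed(data)  # Mirrors client->TransferToInfeed(*data).
  thread = threading.Thread(target=loop)
  thread.daemon = True  # Never joined; the process exits out from under it.
  thread.start()
  return thread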
@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
int64 feature_index = 5;
}

message BatchNormInferenceRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
ComputationDataHandle offset = 3;
ComputationDataHandle mean = 4;
ComputationDataHandle variance = 5;
float epsilon = 6;
int64 feature_index = 7;
}

message BatchNormGradRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
@ -813,7 +823,8 @@ message OpRequest {
OutfeedRequest outfeed_request = 32;
BatchNormTrainingRequest batch_norm_training_request = 35;
BatchNormGradRequest batch_norm_grad_request = 37;
// Next: 38
BatchNormInferenceRequest batch_norm_inference_request = 38;
// Next: 39
}
}

@ -23,6 +23,7 @@ py_library(
"//tensorflow/contrib/data",
"//tensorflow/contrib/deprecated:deprecated_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
@ -77,9 +78,8 @@ py_library(
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],

@ -241,6 +241,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.

@ -21,6 +21,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
@ -28,6 +29,27 @@ py_test(
],
)

py_test(
name = "iterator_ops_cluster_test",
size = "small",
srcs = ["iterator_ops_cluster_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)

py_test(
name = "batch_dataset_op_test",
size = "small",

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np

from tensorflow.contrib.data.python.ops import dataset_ops
@ -229,6 +231,16 @@ class DatasetConstructorTest(test.TestCase):
dtypes.int64), dataset.output_types)
self.assertEquals(([], ([], []), []), dataset.output_shapes)

def testNestedDict(self):
components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
self.assertEquals(dtypes.int32, dataset.output_types["b"])
self.assertEquals([], dataset.output_shapes["a"]["aa"])
self.assertEquals([2], dataset.output_shapes["a"]["ab"])
self.assertEquals([3], dataset.output_shapes["b"])

def testNonSequenceNestedStructure(self):
components = np.array([1, 2, 3])

@ -255,6 +267,214 @@ class DatasetConstructorTest(test.TestCase):
self.assertEquals(dtypes.int64, get_next.dtype)
self.assertEquals([3], get_next.shape)

def _testFromGenerator(self, generator, elem_sequence, num_repeats):
iterator = (
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
.repeat(num_repeats)
.prefetch(5)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
for _ in range(2):  # Run twice to test reinitialization.
sess.run(init_op)
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
iterator = (
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
.repeat(num_repeats)
.prefetch(5)
.make_one_shot_iterator())
get_next = iterator.get_next()

with self.test_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testFromGeneratorUsingFunction(self):
def generator():
for i in range(1, 100):
yield [i] * i
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
self._testFromGeneratorOneShot(generator, elem_sequence, 5)

def testFromGeneratorUsingList(self):
generator = lambda: [[i] * i for i in range(1, 100)]
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)

def testFromGeneratorUsingNdarray(self):
generator = lambda: np.arange(100, dtype=np.int64)
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)

def testFromGeneratorUsingGeneratorExpression(self):
# NOTE(mrry): Generator *expressions* are not repeatable (or in
# general reusable), because they eagerly evaluate the `for`
# expression as `iter(range(1, 100))` and discard the means of
# reconstructing `range(1, 100)`. Wrapping the generator
# expression in a `lambda` makes it repeatable.
generator = lambda: ([i] * i for i in range(1, 100))
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)

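The note above is worth seeing in plain Python: a generator expression is exhausted after one pass, while a lambda that rebuilds the expression yields a fresh generator on every call.

gen_expr = (i * i for i in range(3))
print(list(gen_expr))  # [0, 1, 4]
print(list(gen_expr))  # [] -- the expression is exhausted after one pass.

make_gen = lambda: (i * i for i in range(3))
print(list(make_gen()))  # [0, 1, 4]
print(list(make_gen()))  # [0, 1, 4] -- each call builds a fresh generator.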
def testFromMultipleConcurrentGenerators(self):
num_inner_repeats = 5
num_outer_repeats = 100

def generator():
for i in range(1, 10):
yield ([i] * i, [i, i ** 2, i ** 3])
input_list = list(generator())

# The interleave transformation is essentially a flat map that
# draws from multiple input datasets concurrently (in a cyclic
# fashion). By placing `Dataset.from_generator()` inside an
# interleave, we test its behavior when multiple iterators are
# active at the same time; by additionally prefetching inside the
# interleave, we create the possibility of parallel (modulo GIL)
# invocations to several iterators created by the same dataset.
def interleave_fn(_):
return (dataset_ops.Dataset.from_generator(
generator, output_types=(dtypes.int64, dtypes.int64),
output_shapes=([None], [3]))
.repeat(num_inner_repeats).prefetch(5))

iterator = (
dataset_ops.Dataset.range(num_outer_repeats)
.interleave(interleave_fn, cycle_length=10,
block_length=len(input_list))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
val0, val1 = sess.run(get_next)
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

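The cyclic draw order that `interleave()` imposes can be sketched in a few lines of plain Python (a simplified model over eager lists -- no prefetching or parallelism, so only an approximation of the real op):

def interleave(inputs, cycle_length, block_length):
  """Cyclically draws block_length items at a time from cycle_length inputs."""
  active = [iter(x) for x in inputs[:cycle_length]]
  pending = list(inputs[cycle_length:])
  while active:
    for it in list(active):
      for _ in range(block_length):
        try:
          yield next(it)
        except StopIteration:
          active.remove(it)
          if pending:  # An exhausted input frees a slot for the next one.
            active.append(iter(pending.pop(0)))
          break

# list(interleave([[1, 1], [2, 2], [3, 3]], 2, 1)) == [1, 2, 1, 2, 3, 3]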
def testFromGeneratorsRunningInParallel(self):
num_parallel_iterators = 3

# Define shared state that multiple iterator instances will access to
# demonstrate their concurrent activity.
lock = threading.Lock()
condition = threading.Condition(lock)
next_ticket = [0]  # GUARDED_BY(lock)

def generator():
# NOTE(mrry): We yield one element before the barrier, because
# the current implementation of `Dataset.interleave()` must
# fetch one element from each incoming dataset to start the
# prefetching.
yield 0

# Define a barrier that `num_parallel_iterators` iterators must enter
# before any can proceed. Demonstrates that multiple iterators may be
# active at the same time.
condition.acquire()
ticket = next_ticket[0]
next_ticket[0] += 1
if ticket == num_parallel_iterators - 1:
# The last iterator to join the barrier notifies the others.
condition.notify_all()
else:
# Wait until the last iterator enters the barrier.
while next_ticket[0] < num_parallel_iterators:
condition.wait()
condition.release()

yield 1

# As in `testFromMultipleConcurrentGenerators()`, we use a combination of
# `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
# iterators to be active concurrently.
def interleave_fn(_):
return dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)

iterator = (
dataset_ops.Dataset.range(num_parallel_iterators)
.interleave(
interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testFromGeneratorTypeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
yield np.array([4, 5, 6], dtype=np.int64)
yield "ERROR"
yield np.array([7, 8, 9], dtype=np.int64)

iterator = (dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of type .*int64.* was expected"):
sess.run(get_next)
self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testFromGeneratorShapeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
yield np.array([4, 5, 6], dtype=np.int64)
yield np.array([7, 8, 9, 10], dtype=np.int64)
yield np.array([11, 12, 13], dtype=np.int64)

iterator = (dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
sess.run(get_next)
self.assertAllEqual([11, 12, 13], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)


if __name__ == "__main__":
test.main()

@ -0,0 +1,109 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops that need test_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.platform import test


class IteratorClusterTest(test.TestCase):

def testRemoteIteratorWithoutRemoteCallFail(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker, _ = test_util.create_local_cluster(
1, 1, worker_config=worker_config)

with ops.device("/job:worker/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()

with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_it = dataset_ops.Iterator.from_string_handle(
iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
get_next_op = remote_it.get_next()

with session.Session(worker[0].target) as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next_op)

def testRemoteIteratorUsingRemoteCallOp(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker, _ = test_util.create_local_cluster(
1, 1, worker_config=worker_config)

with ops.device("/job:worker/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()

@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

with ops.device("/job:worker/replica:0/task:0/cpu:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
remote_op = functional_ops.remote_call(
args=[iterator_3_handle],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)

with session.Session(worker[0].target) as sess:
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:worker/replica:0/task:0/cpu:0"
})
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:worker/replica:0/task:0/cpu:1"
})


if __name__ == "__main__":
test.main()

@ -25,8 +25,10 @@ from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -416,6 +418,62 @@ class IteratorTest(test.TestCase):
feedable_int_vector.get_next(),
feed_dict={handle_placeholder: handle_float_vector}))

def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2

with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()

@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
remote_op = functional_ops.remote_call(
args=[iterator_3_handle],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)

with self.test_session(config=worker_config) as sess:
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})


if __name__ == "__main__":
test.main()

@ -549,6 +549,41 @@ class MapDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReturnList(self):
    iterator = (dataset_ops.Dataset.range(10)
                .map(lambda x: [x, constant_op.constant(37.0)])
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMultiOutputPyFunc(self):
    # The `tf.py_func()` op returns a list of tensors for its outputs.
    def _map_fn(x_tensor):
      def _map_py_func(x):
        return x, np.array(37.0, dtype=np.float64)
      return script_ops.py_func(
          _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])

    iterator = (dataset_ops.Dataset.range(10)
                .map(_map_fn)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()
@ -24,6 +24,7 @@ py_library(
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:script_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function

import abc
import collections
import threading
import warnings

import numpy as np
@ -40,6 +42,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import gfile
@ -559,6 +562,168 @@ class Dataset(object):
    """
    return SparseTensorSliceDataset(sparse_tensor)

  class _GeneratorState(object):
    """Stores outstanding iterators created from a Python generator.

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      self._iterators = collections.defaultdict(lambda: iter(generator()))

    def get_next_id(self):
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      return ret

    def get_iterator(self, iterator_id):
      return self._iterators[iterator_id]

    def iterator_completed(self, iterator_id):
      del self._iterators[iterator_id]

  @staticmethod
  def from_generator(generator, output_types, output_shapes=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).
    The elements generated by `generator` must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    Args:
      generator: A callable object that takes no arguments and returns an
        object that supports the `iter()` protocol.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.

    Returns:
      A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    flattened_types = nest.flatten(output_types)
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = Dataset._GeneratorState(generator)

    def get_iterator_id_map_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The "iterator_id" disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.py_func(
          generator_state.get_next_id, [], dtypes.int64, stateful=True)

    def generator_map_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
          the iterator in `generator_state` from which to generate an element.

      Returns:
        A nested structure of tensors representing an element from the
        iterator.
      """

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        try:
          values = next(generator_state.get_iterator(iterator_id))
        except StopIteration:
          generator_state.iterator_completed(iterator_id)
          raise StopIteration("Iteration finished.")

        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        # pylint: disable=protected-access
        ret_arrays = [script_ops.FuncRegistry._convert(ret)
                      for ret in nest.flatten_up_to(output_types, values)]
        # pylint: enable=protected-access

        # Additional type and shape checking to ensure that the components
        # of the generated element match the `output_types` and
        # `output_shapes` arguments.
        for (ret_array, expected_dtype, expected_shape) in zip(
            ret_arrays, flattened_types, flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                "`generator` yielded an element of type %s where an element "
                "of type %s was expected."
                % (ret_array.dtype, expected_dtype.as_numpy_dtype))
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise ValueError(
                "`generator` yielded an element of shape %s where an element "
                "of shape %s was expected." % (ret_array.shape,
                                               expected_shape))

        return ret_arrays

      flat_values = script_ops.py_func(
          generator_py_func, [iterator_id_t], flattened_types, stateful=True)

      # The `py_func()` op drops the inferred shapes, so we add them back in
      # here.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(iterator_id_t):
      # First, generate an infinite dataset containing the iterator ID
      # repeated forever.
      repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)

      # The `generator_map_fn` gets the next element from the iterator with
      # the relevant ID, and raises StopIteration when that iterator contains
      # no more elements.
      return repeated_id.map(generator_map_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

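A minimal usage sketch of the new `from_generator` API (reviewer's illustration, not part of this change; it assumes the module-style names used elsewhere in this file, `dtypes`, `tensor_shape`, and a graph-mode `session.Session`):

    import numpy as np

    def gen():
      for i in range(3):
        yield (i, np.array([i, i], dtype=np.int64))

    ds = Dataset.from_generator(
        gen, (dtypes.int64, dtypes.int64),
        (tensor_shape.TensorShape([]), tensor_shape.TensorShape([2])))
    # Each repetition gets its own iterator id, so `repeat()` re-invokes the
    # generator from the start rather than resuming a spent iterator.
    iterator = ds.repeat(2).make_initializable_iterator()
    get_next = iterator.get_next()
    with session.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(get_next))  # (0, array([0, 0]))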
  @staticmethod
  def range(*args):
    """Creates a `Dataset` of a step-separated range of values.
@ -1123,6 +1288,11 @@ class Dataset(object):
        }
    ```

    NOTE: The order of elements yielded by this transformation is
    deterministic, as long as `map_func` is a pure function. If
    `map_func` contains any stateful operations, the order in which
    that state is accessed is undefined.

    Args:
      map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to a
@ -1821,6 +1991,19 @@ class MapDataset(Dataset):
    else:
      ret = map_func(nested_args)

    # If `map_func` returns a list of tensors, `nest.flatten()` and
    # `ops.convert_to_tensor()` would conspire to attempt to stack
    # those tensors into a single tensor, because the customized
    # version of `nest.flatten()` does not recurse into lists. Since
    # it is more likely that the list arose from returning the
    # result of an operation (such as `tf.py_func()`) that returns a
    # list of not-necessarily-stackable tensors, we treat the
    # returned value as a `tuple` instead. A user wishing to pack
    # the return value into a single tensor can use an explicit
    # `tf.stack()` before returning.
    if isinstance(ret, list):
      ret = tuple(ret)

    # Extract shape information from the returned values.
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
    self._output_shapes = nest.pack_sequence_as(

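A small sketch of the list-versus-tuple rule described in the comment above (reviewer's illustration, not from the diff; `Dataset` and `array_ops` as used elsewhere in this change):

    # A list return becomes a two-component element, as if a tuple were
    # returned.
    ds_pair = Dataset.range(4).map(lambda x: [x, x * 2])
    # For a single stacked tensor per element, stack explicitly instead.
    ds_stacked = Dataset.range(4).map(lambda x: array_ops.stack([x, x * 2]))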
@ -40,6 +40,14 @@ import six as _six
from tensorflow.python.util.all_util import remove_undocumented


def _sorted(dict_):
  """Returns a sorted list of the dict keys, with error if keys not sortable."""
  try:
    return sorted(_six.iterkeys(dict_))
  except TypeError:
    raise TypeError("nest only supports dicts with sortable keys.")


def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

@ -51,9 +59,13 @@ def _sequence_like(instance, args):
      `args` with the type of `instance`.
  """
  if isinstance(instance, dict):
    # This is a dict. Iterate over the keys in sorted order to make
    # this deterministic.
    return {k: v for k, v in zip(sorted(instance.keys()), args)}
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
  elif (isinstance(instance, tuple) and
        hasattr(instance, "_fields") and
        isinstance(instance._fields, _collections.Sequence) and
@ -65,16 +77,22 @@ def _sequence_like(instance, args):
    return type(instance)(args)


def _elements_of(nest):
  if isinstance(nest, dict):
    # Iterate over dict keys in sorted order to make this deterministic.
    return [v for _, v in sorted(nest.items())]
def _yield_value(iterable):
  if isinstance(iterable, dict):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield iterable[key]
  else:
    return nest
    for value in iterable:
      yield value


def _yield_flat_nest(nest):
  for n in _elements_of(nest):
  for n in _yield_value(nest):
    if is_sequence(n):
      for ni in _yield_flat_nest(n):
        yield ni
@ -132,7 +150,7 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
          "structure has type %s, while second structure has type %s."
          % (type_nest1, type_nest2))

    for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
    for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
      _recursive_assert_same_structure(n1, n2, check_types)


@ -181,7 +199,7 @@ def _packed_nest_with_indices(structure, flat, index):
      (assuming indexing starts from `index`).
  """
  packed = []
  for s in structure:
  for s in _yield_value(structure):
    if is_sequence(s):
      new_index, child = _packed_nest_with_indices(s, flat, index)
      packed.append(_sequence_like(s, child))
@ -286,8 +304,8 @@ def map_structure(func, *structure, **check_types_dict):
def _yield_flat_up_to(shallow_tree, input_tree):
  """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
  if is_sequence(shallow_tree):
    for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
                                            _elements_of(input_tree)):
    for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
                                            _yield_value(input_tree)):
      for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
        yield input_leaf
  else:
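One behavioral consequence worth noting: the new `_sorted` helper turns unsortable dict keys into an explicit error. A hypothetical illustration (not among the tests below; under Python 3, where mixed key types do not compare):

    nest.flatten({1: "a", "x": "b"})
    # TypeError: nest only supports dicts with sortable keys.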
@ -65,6 +65,73 @@ class NestTest(test.TestCase):
    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])

  def testFlattenDictOrder(self):
    """`flatten` orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
    plain = {"d": 3, "b": 1, "a": 0, "c": 2}
    ordered_flat = nest.flatten(ordered)
    plain_flat = nest.flatten(plain)
    self.assertEqual([0, 1, 2, 3], ordered_flat)
    self.assertEqual([0, 1, 2, 3], plain_flat)

  def testPackDictOrder(self):
    """Packing orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
    plain = {"d": 0, "b": 0, "a": 0, "c": 0}
    seq = [0, 1, 2, 3]
    ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
    plain_reconstruction = nest.pack_sequence_as(plain, seq)
    self.assertEqual(
        collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
        ordered_reconstruction)
    self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

  def testFlattenAndPack_withDicts(self):
    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
    named_tuple = collections.namedtuple("A", ("b", "c"))
    mess = (
        "z",
        named_tuple(3, 4),
        {
            "c": (
                1,
                collections.OrderedDict([
                    ("b", 3),
                    ("a", 2),
                ]),
            ),
            "b": 5
        },
        17
    )

    flattened = nest.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    structure_of_mess = (
        14,
        named_tuple("a", True),
        {
            "c": (
                0,
                collections.OrderedDict([
                    ("b", 9),
                    ("a", 8),
                ]),
            ),
            "b": 3
        },
        "hi everybody",
    )

    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    self.assertEqual(unflattened, mess)

    # Check also that the OrderedDict was created, with the correct key order.
    unflattened_ordered_dict = unflattened[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

  def testIsSequence(self):
    self.assertFalse(nest.is_sequence("1234"))
    self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
@ -50,6 +50,7 @@ from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.test_util import *
from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import *
from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
from tensorflow.contrib.distributions.python.ops.wishart import *
tensorflow/contrib/eager/python/BUILD (new file, 32 lines)
@ -0,0 +1,32 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow:internal"])

py_library(
    name = "tfe",
    srcs = ["tfe.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:core",
        "//tensorflow/python/eager:custom_gradient",
        "//tensorflow/python/eager:execution_callbacks",
        "//tensorflow/python/eager:function",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
            "g3doc/sitemap.md",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/eager/python/tfe.py (new file, 80 lines)
@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Eager execution prototype.

EXPERIMENTAL: APIs here are unstable and likely to change without notice.

To use, at program startup, call `tfe.enable_eager_execution()`.

@@list_devices
@@device


@@defun
@@implicit_gradients
@@implicit_value_and_gradients
@@gradients_function
@@value_and_gradients_function

@@enable_tracing
@@flush_trace

@@run
@@enable_eager_execution

@@custom_gradient

@@add_execution_callback
@@clear_execution_callbacks
@@inf_callback
@@inf_nan_callback
@@nan_callback
@@seterr
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import
#
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.eager import backprop
from tensorflow.python.eager.custom_gradient import custom_gradient
from tensorflow.python.eager import function
from tensorflow.python.eager.context import context
from tensorflow.python.eager.context import device
from tensorflow.python.eager.context import enable_eager_execution
from tensorflow.python.eager.context import run
from tensorflow.python.eager.core import enable_tracing
from tensorflow.python.eager.execution_callbacks import add_execution_callback
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
from tensorflow.python.eager.execution_callbacks import inf_callback
from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr


def list_devices():
  return context().devices()

defun = function.defun
implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad
gradients_function = backprop.gradients_function
value_and_gradients_function = backprop.val_and_grad_function

remove_undocumented(__name__)
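A quick usage sketch of the surface this module exposes (reviewer's illustration, not part of the change; the import path mirrors the new file's location and assumes the TF 1.x-era API):

    import tensorflow as tf
    from tensorflow.contrib.eager.python import tfe

    tfe.enable_eager_execution()  # must be called at program startup

    def square(x):
      return tf.multiply(x, x)

    grad_fn = tfe.gradients_function(square)
    print(square(3.))   # executes immediately; no Session required
    print(grad_fn(3.))  # => [6.0]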
@ -119,6 +119,18 @@ py_test(
    ],
)

py_test(
    name = "logit_fns_test",
    size = "small",
    srcs = ["python/learn/estimators/logit_fns_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":learn",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/estimator:model_fn",
    ],
)

py_test(
    name = "estimators_test",
    size = "small",
@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
@ -21,7 +21,7 @@ should follow the following signature:
Args:
`features`: This is the first item returned from the `input_fn` passed to
            `train`, `evaluate`, and `predict`. This should be a single
            `Tensor` or `dict` of same.
            `Tensor` or `dict` of same, and is the only required argument.
`mode`: Optional. Specifies if this is training, evaluation or prediction. See
        `ModeKeys`.
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
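To make the relaxed contract concrete, a hypothetical minimal logit_fn (reviewer's illustration, not from the diff): `features` is the only required argument, and `call_logit_fn` below simply omits the keyword arguments the signature does not declare:

    def minimal_logit_fn(features):
      # No `mode`, `params`, or `config` in the signature; that is now valid.
      return features['x'] * 2.0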
@ -39,10 +39,47 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops

# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
# pylint: enable=protected-access


def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  if not isinstance(logit_fn_results, ops.Tensor):
    raise ValueError('model_fn should return a Tensor.')

  return logit_fn_results
@ -0,0 +1,60 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""logit_fn tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators import logit_fns
from tensorflow.python.client import session
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test


class LogitFnTest(test.TestCase):

  def test_simple_call_logit_fn(self):
    def dummy_logit_fn(features, mode):
      if mode == model_fn.ModeKeys.TRAIN:
        return features['f1']
      else:
        return features['f2']
    features = {
        'f1': constant_op.constant([2., 3.]),
        'f2': constant_op.constant([4., 5.])
    }
    logit_fn_result = logit_fns.call_logit_fn(
        dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
        'fake_config')
    with session.Session():
      self.assertAllClose([[4., 5.]], logit_fn_result.eval())

  def test_should_return_tensor(self):

    def invalid_logit_fn(features, params):
      return {
          'tensor1': features['f1'] * params['input_multiplier'],
          'tensor2': features['f2'] * params['input_multiplier']
      }
    features = {
        'f1': constant_op.constant([2., 3.]),
        'f2': constant_op.constant([4., 5.])
    }
    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'):
      logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                              'fake_config')
@ -73,8 +73,9 @@ HOST_INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(HOST_GENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
	HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include
@ -156,6 +157,7 @@ INCLUDES := \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
@ -237,6 +239,7 @@ ifeq ($(TARGET),ANDROID)
	$(error "NDK_ROOT is not defined.")
	endif
	CXX := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-g++
	CC := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-gcc
	CXXFLAGS +=\
	--sysroot $(NDK_ROOT)/platforms/android-21/arch-arm \
	-Wno-narrowing \
@ -244,7 +247,6 @@ ifeq ($(TARGET),ANDROID)
	-mfloat-abi=softfp \
	-mfpu=neon \
	-fPIE

	INCLUDES = \
	-I$(NDK_ROOT)/sources/android/support/include \
	-I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/include \
@ -254,6 +256,7 @@ ifeq ($(TARGET),ANDROID)
	-I$(MAKEFILE_DIR)/downloads/eigen \
	-I$(MAKEFILE_DIR)/downloads/gemmlowp \
	-I$(MAKEFILE_DIR)/downloads/nsync/public \
	-I$(MAKEFILE_DIR)/downloads/fft2d \
	-I$(MAKEFILE_DIR)/gen/protobuf/include \
	-I$(PROTOGENDIR) \
	-I$(PBTGENDIR)
@ -502,11 +505,13 @@ $(wildcard tensorflow/core/user_ops/*.cu.cc) \
$(wildcard tensorflow/core/common_runtime/gpu/*) \
$(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \
$(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \
$(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \
$(wildcard tensorflow/core/grappler/clusters/single_machine.*)
# Filter out all the excluded files.
TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# Add in any extra files that don't fit the patterns easily
TF_CC_SRCS += tensorflow/core/platform/default/gpu_tracer.cc
TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c
# Also include the op and kernel definitions.
TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt)
PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt)
@ -529,7 +534,8 @@ tensorflow/core/kernels/hexagon/hexagon_remote_fused_graph_executor_build.cc
endif

# File names of the intermediate files target compilation generates.
TF_CC_OBJS := $(addprefix $(OBJDIR), $(TF_CC_SRCS:.cc=.o))
TF_CC_OBJS := $(addprefix $(OBJDIR), \
	$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_CC_SRCS))))
PBT_GEN_FILES := $(addprefix $(PBTGENDIR), $(PBT_CC_SRCS))
PBT_OBJS := $(addprefix $(OBJDIR), $(PBT_CC_SRCS:.cc=.o))
PROTO_CC_SRCS := $(addprefix $(PROTOGENDIR), $(PROTO_SRCS:.proto=.pb.cc))
@ -567,6 +573,14 @@ $(OBJDIR)%.o: %.cc | $(PBT_GEN_FILES)
	$(CXX) $(CXXFLAGS) $(DEPFLAGS) $(INCLUDES) -c $< -o $@
	@mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d

# Matches on plain C files.
$(OBJDIR)%.o: %.c
	@mkdir -p $(dir $@)
	@mkdir -p $(dir $(DEPDIR)$*)
	$(CXX) $(patsubst --std=c++11,--std=c99, $(CXXFLAGS)) -x c $(DEPFLAGS) \
	$(INCLUDES) -c $< -o $@
	@mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d

# Compiles C++ source files that have been generated by protoc.
$(OBJDIR)%.pb.o: $(PROTOGENDIR)%.pb.cc
	@mkdir -p $(dir $@)
@ -25,6 +25,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
NSYNC_URL="$(grep -o 'http.*github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
PROTOBUF_URL="$(grep -o 'http.*github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
RE2_URL="$(grep -o 'http.*github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"

# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
@ -60,6 +61,7 @@ download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest"
download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync"
download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf"
download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2"
download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d"

replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
  "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
Some files were not shown because too many files have changed in this diff.