Merge pull request #12455 from caisq/branch_165951046

Branch 165951046
Shanqing Cai 2017-08-21 16:15:59 -04:00 committed by GitHub
commit b0d6bf3425
232 changed files with 7829 additions and 1983 deletions

View File

@ -80,3 +80,13 @@ new_http_archive(
"http://download.tensorflow.org/models/stylize_v1.zip",
],
)
new_http_archive(
name = "speech_commands",
build_file = "models.BUILD",
sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
"http://download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)

View File

@ -286,6 +286,7 @@ filegroup(
"//tensorflow/contrib/data/python/util:all_files",
"//tensorflow/contrib/decision_trees/proto:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/eager/python:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",
"//tensorflow/contrib/ffmpeg:all_files",

View File

@ -1931,6 +1931,8 @@ bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
src.second, dst_inputs[i]);
}
opts.get()->opts.skip_mapped_nodes = true;
// We use the pivot node to control constants in `src_graph`.
TF_Operation* pivot = dst_inputs[0].oper;
TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);

View File

@ -64,19 +64,14 @@ struct TFE_Context {
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].
std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
tensorflow::Fprint128Hasher>
kernel_cache;
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
for (int i = 0; i < session->devices.size(); ++i) {
if (session->devices[i] == d) {
return func_libs[i].get();
}
}
return nullptr;
return pflr->GetFLR(d->name());
}
const std::vector<tensorflow::Device*>& devices() { return session->devices; }
@ -132,12 +127,9 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
}
TFE_Context* ret = new TFE_Context(session);
ret->func_libs.resize(ret->devices().size());
for (int i = 0; i < ret->devices().size(); ++i) {
ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
ret->session->device_mgr, opts->options.env, ret->devices()[i],
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
}
ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
&ret->func_lib_def, {}));
ret->rendezvous =
new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
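
As an aside on the design change above: instead of building one FunctionLibraryRuntime per device and indexing it in parallel with session->devices, a single process-level object now owns all of them and is queried by device name. Below is a minimal Python sketch of that lookup pattern, using hypothetical names that are not part of the TensorFlow API.

# Hypothetical sketch only; these names do not exist in TensorFlow.
class ProcessRuntimeRegistrySketch:
    """Owns one function-library runtime per device, keyed by device name."""

    def __init__(self, device_names):
        # One runtime object per device; a dict replaces the parallel vector.
        self._flr_by_name = {name: object() for name in device_names}

    def get_flr(self, device_name):
        # Unknown devices yield None, mirroring a failed lookup.
        return self._flr_by_name.get(device_name)


registry = ProcessRuntimeRegistrySketch(["/device:CPU:0", "/device:GPU:0"])
assert registry.get_flr("/device:CPU:0") is not None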

View File

@ -174,6 +174,16 @@ TEST_F(CApiWhileLoopTest, BasicLoop) {
EXPECT_TRUE(outputs_[1].oper != nullptr);
EXPECT_GE(outputs_[1].index, 0);
// Check that cond and body inputs are not present
for (int i = 0; i < params_->ninputs; ++i) {
string cond_name =
::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
string body_name =
::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
}
// Run the graph
Run({-9, 2});
ExpectOutputValue(0, 3);

View File

@ -481,6 +481,42 @@ Status AddNGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);
// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs,
const Output& comparator) {
// comparator is a boolean tensor, with
// y = x_1 at points where comparator is true, and x_2 otherwise
// Therefore
// dy/dx_1 = 1 where comparator is true, and 0 otherwise.
// dy/dx_2 = 0 where comparator is true, and 1 otherwise.
auto grad = grad_inputs[0];
auto zeros = ZerosLike(scope, grad);
auto gx_1 = Where3(scope, comparator, grad, zeros);
auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros);
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
Status MaximumGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);
Status MinimumGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto comparator = LessEqual(scope, op.input(0), op.input(1));
return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);
Status RealGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
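
To make the comment in MaximumMinimumGradCommon concrete, here is a minimal NumPy sketch (illustration only, with a hypothetical helper name, not the TensorFlow implementation) of how the upstream gradient is routed: it flows to x_1 wherever the comparator is true and to x_2 wherever it is false, so every element goes to exactly one input.

import numpy as np

def maximum_grad_sketch(x1, x2, grad):
    # For y = maximum(x1, x2): dy/dx1 is 1 where x1 >= x2 and 0 otherwise;
    # dy/dx2 is the complement. Ties follow the >= convention used above.
    comparator = x1 >= x2
    gx1 = np.where(comparator, grad, np.zeros_like(grad))
    gx2 = np.where(~comparator, grad, np.zeros_like(grad))
    return gx1, gx2

x1 = np.array([0.5, 1.5, -1.2])
x2 = np.array([1.0, 1.0, 1.0])
gx1, gx2 = maximum_grad_sketch(x1, x2, np.ones_like(x1))
# gx1 == [0., 1., 0.] and gx2 == [1., 0., 1.]: each element of the upstream
# gradient is routed to exactly one of the two inputs.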

View File

@ -925,6 +925,15 @@ class NaryGradTest : public ::testing::Test {
EXPECT_LT(max_error, 1e-3);
}
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
EXPECT_LT(max_error, 1e-3);
}
Scope scope_;
};
@ -993,5 +1002,27 @@ TEST_F(NaryGradTest, SquaredDifference) {
RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
TEST_F(NaryGradTest, Maximum) {
TensorShape shape({3, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Maximum(scope_, x, Const(scope_, 1.0f));
// Select values away from 1.0f to avoid instability when computing
// finite differences.
Tensor x_init_value =
test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
RunTest(x, x_init_value, y, shape);
}
TEST_F(NaryGradTest, Minimum) {
TensorShape shape({3, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Minimum(scope_, x, Const(scope_, 1.0f));
// Select values away from 1.0f to avoid instability when computing
// finite differences.
Tensor x_init_value =
test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
RunTest(x, x_init_value, y, shape);
}
} // namespace
} // namespace tensorflow
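
The x_init_value choices above keep every element away from the kink of max(x, 1) / min(x, 1) at x = 1, where numerical gradient checking is unreliable. A small NumPy sketch (illustration only, hypothetical helper) of why:

import numpy as np

def central_difference(f, x, delta=1e-2):
    # Symmetric finite difference of the kind used by numerical gradient checkers.
    return (f(x + delta) - f(x - delta)) / (2.0 * delta)

f = lambda x: np.maximum(x, 1.0)
print(central_difference(f, 0.5))  # ~0.0, matches the true derivative below the kink
print(central_difference(f, 1.5))  # ~1.0, matches the true derivative above the kink
print(central_difference(f, 1.0))  # ~0.5, straddles the kink: matches neither side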

View File

@ -178,6 +178,9 @@ def tf_library(name, graph, config,
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",

View File

@ -624,15 +624,18 @@ Status EncapsulateSubgraphsPass::Run(
FunctionLibraryDefinition* const library = options.flib_def;
OptimizerOptions opts;
std::unique_ptr<FunctionLibraryRuntime> flr(
NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
TF_GRAPH_DEF_VERSION, library, opts));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
TF_GRAPH_DEF_VERSION, library, opts));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
auto rewrite_subgraph = [&flr](
std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node) {
auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* node) {
// Optimize the subgraph.
OptimizeGraph(flr.get(), subgraph);
OptimizeGraph(flr, subgraph);
const int num_args = input_permutation->size();
std::vector<bool> const_args(num_args);

View File

@ -176,8 +176,11 @@ Status FindCompilationCandidates(
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
std::unordered_set<Node*>* candidates) {
OptimizerOptions opts;
std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
for (Node* node : graph.op_nodes()) {
DeviceType device_type("");
@ -191,7 +194,7 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
<< ": " << node->type_string();
continue;
@ -203,7 +206,7 @@ Status FindCompilationCandidates(
continue;
}
if (node->type_string() == "While" &&
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
continue;
}
candidates->insert(node);

View File

@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
def testInference(self):
x_shape = [2, 2, 6, 2]
scale_shape = [2]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
data_format = "NHWC"
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
epsilon = 0.001
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format)
y, mean, variance = nn.fused_batch_norm(
t_val,
scale,
offset,
mean=mean_ref,
variance=var_ref,
epsilon=epsilon,
data_format=data_format,
is_training=False)
y_val, _, _ = sess.run(
[y, mean,
variance], {t_val: x_val,
scale: scale_val,
offset: offset_val})
self.assertAllClose(y_val, y_ref, atol=1e-3)
def _testLearning(self, use_gradient_checker):
x_shape = [2, 2, 6, 2]
scale_shape = [2]

View File

@ -260,6 +260,13 @@ class UnaryOpsTest(XLATestCase):
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array(
[[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype),
expected=np.array(
[[True, False, True], [False, True, True]], dtype=np.bool))
def testNumericOps(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(

View File

@ -31,6 +31,7 @@ tf_kernel_library(
"function_ops.cc",
"gather_op.cc",
"identity_op.cc",
"is_finite_op.cc",
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",

View File

@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
errors::InvalidArgument("Not supported format"));
feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
}
// TODO(b/62843645): Implement BatchNormInference.
OP_REQUIRES(
ctx, is_training_,
errors::InvalidArgument("Fused batch normalization for inference is "
"not supported yet on XLA backend."));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
if (is_training_) {
xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
feature_index_);
// In training mode, outputs the normalized value as well as the calculated
// mean and variance.
for (int i = 0; i < 3; i++) {
ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
// In training mode, outputs the normalized value as well as the
// calculated mean and variance.
for (int i = 0; i < 3; i++) {
ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
}
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
// space 1 & 2". They are used to pass the per-batch mean and
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
} else {
xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
ctx->Input(4), epsilon_, feature_index_);
ctx->SetOutput(0, output);
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
ctx->SetOutput(3, ctx->Input(3));
ctx->SetOutput(4, ctx->Input(4));
}
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
// space 1 & 2". They are used to pass the per-batch mean and
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
}
private:
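
For reference, a minimal NumPy sketch (illustration only, hypothetical helper name, not the XLA kernel) of what the inference branch computes: the output is normalized with the supplied mean and variance, and those same inputs are forwarded as outputs 1 through 4.

import numpy as np

def fused_batch_norm_inference_sketch(x, scale, offset, mean, variance, epsilon=1e-3):
    # NHWC layout: scale/offset/mean/variance are per-channel (last axis) and broadcast.
    y = scale * (x - mean) / np.sqrt(variance + epsilon) + offset
    # In inference mode the estimated moments are simply forwarded, mirroring
    # outputs 1-4 set by the kernel above.
    return y, mean, variance, mean, variance

x = np.random.random_sample((2, 2, 6, 2)).astype(np.float32)
scale = np.random.random_sample(2).astype(np.float32)
offset = np.random.random_sample(2).astype(np.float32)
mean = np.random.random_sample(2).astype(np.float32)
variance = np.random.random_sample(2).astype(np.float32)
y, *_ = fused_batch_norm_inference_sketch(x, scale, offset, mean, variance)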

View File

@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
class IsFiniteOp : public XlaOpKernel {
public:
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle input = ctx->Input(0);
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
};
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -139,7 +139,7 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
out_index);
auto ip1 = bodyb.Add(i, bodyb.ConstantR0<int32>(1));
bodyb.Tuple({ip1, data, indices_1d, updated_output});
bodyb.Tuple({ip1, data, idcs, updated_output});
}
auto body_status = bodyb.Build();
// TF_CHECK_OK(body_status);

View File

@ -88,15 +88,18 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
}
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
FunctionDefLibrary{}));
local_flib_runtime_ = NewFunctionLibraryRuntime(
&device_mgr_, Env::Default(), device_, options.graph_def_version,
local_pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version,
local_flib_def_.get(), OptimizerOptions(),
nullptr /* custom_kernel_creator */);
flib_runtime_ = NewFunctionLibraryRuntime(
&device_mgr_, Env::Default(), device_, options.graph_def_version,
options.flib_def, OptimizerOptions(),
nullptr /* custom_kernel_creator */);
nullptr /* custom_kernel_creator */));
pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
OptimizerOptions(), nullptr /* custom_kernel_creator */));
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name());
}
XlaCompiler::~XlaCompiler() = default;
@ -137,8 +140,8 @@ Status XlaCompiler::CompileFunction(
}
const FunctionBody* fbody;
if (!GetFunctionBody(function, local_flib_runtime_.get(), &fbody).ok()) {
TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_.get(), &fbody));
if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) {
TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody));
}
TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args));
@ -159,7 +162,7 @@ Status XlaCompiler::CompileFunction(
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, /*shape_map=*/nullptr);
VLOG(1) << "====================================================";
@ -464,7 +467,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
context->set_args(std::move(context_args));
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_.get(), NextStepId()));
flib_runtime_, NextStepId()));
int num_nonconst_outputs;
int num_computation_outputs;

View File

@ -276,7 +276,7 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
XlaCompilationDevice* device() const { return device_; }
const DeviceMgr* device_mgr() const { return &device_mgr_; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
@ -303,9 +303,11 @@ class XlaCompiler {
// library and runtime for functions created as part of the functionalize
// control flow transformation.
std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
std::unique_ptr<FunctionLibraryRuntime> local_flib_runtime_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
FunctionLibraryRuntime* local_flib_runtime_; // owned by local_pflr_.
FunctionLibraryRuntime* flib_runtime_; // owned by pflr_.
struct SignatureHash {
uint64 operator()(

View File

@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& offset, const ComputationDataHandle& mean,
const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
// TODO(b/62843645): Implement BatchNormInference.
NoteError(Unimplemented("BatchNormInference is not implemented yet."));
return ComputationDataHandle();
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
BatchNormInferenceRequest request;
*request.mutable_operand() = operand;
*request.mutable_scale() = scale;
*request.mutable_offset() = offset;
*request.mutable_mean() = mean;
*request.mutable_variance() = variance;
request.set_epsilon(epsilon);
request.set_feature_index(feature_index);
OpRequest op_request;
*op_request.mutable_batch_norm_inference_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making BatchNormInference request";
Status s = client_->stub()->Op(&op_request, &response);
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BatchNormGrad(

View File

@ -48,38 +48,56 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
std::minstd_rand0 engine;
switch (shape.element_type()) {
case F32: {
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<float>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S32: {
std::uniform_int_distribution<int32> generator(
std::numeric_limits<int32>::lowest(),
std::numeric_limits<int32>::max());
TF_CHECK_OK(literal->Populate<int32>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
std::minstd_rand0 engine;
switch (shape.element_type()) {
case F32: {
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<float>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S32: {
std::uniform_int_distribution<int32> generator(
std::numeric_limits<int32>::lowest(),
std::numeric_limits<int32>::max());
TF_CHECK_OK(literal->Populate<int32>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
LOG(WARNING)
<< "Unsupported type for host-side fake data generation: "
<< ShapeUtil::HumanString(shape)
<< "; falling back to making small amount of fake data via device.";
return MakeFakeDataViaDeviceOrDie(shape, client);
StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
CHECK_EQ(literal_status.status().code(),
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
return client->TransferToServer(*literal).ValueOrDie();
return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.
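
A rough Python analogue (hypothetical names, not the XLA API) of the host-side strategy above: recurse into tuple shapes, draw uniform values for the supported element types, and raise an "unimplemented" error for anything else so the caller can fall back to device-side generation.

import numpy as np

_RNG = np.random.default_rng(0)

class UnimplementedError(Exception):
    """Raised for element types the host-side generator does not handle."""

def make_fake_array(dims, dtype):
    # Mirrors the F32/S32 cases above: uniform floats in [0, 1), full-range int32.
    if dtype == np.float32:
        return _RNG.random(dims, dtype=np.float32)
    if dtype == np.int32:
        info = np.iinfo(np.int32)
        return _RNG.integers(info.min, info.max, size=dims, dtype=np.int32,
                             endpoint=True)
    raise UnimplementedError("Unsupported type for fake data: %r" % dtype)

def make_fake_literal(shape_spec):
    # A tuple shape is modeled here as a list of element shape specs.
    if isinstance(shape_spec, list):
        return [make_fake_literal(element) for element in shape_spec]
    dims, dtype = shape_spec
    return make_fake_array(dims, dtype)

# Usage: fall back to a (hypothetical) device-side path only on UnimplementedError.
fake = make_fake_literal([((2, 3), np.float32), ((4,), np.int32)])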

View File

@ -26,6 +26,10 @@ limitations under the License.
namespace xla {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
// Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring
// data from the host to the device.

View File

@ -1101,7 +1101,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
VLOG(4) << " old user: " << user->ToString();
CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
reshape_or_broadcast);
std::vector<HloInstruction*> new_user_operands = user->operands();
auto new_user_operands = user->operands();
new_user_operands[reshape_or_broadcast_operand_index] = operand;
auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
ShapeUtil::MakeShapeWithLayout(
@ -1505,9 +1505,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// We cannot insert bitcasts if the layouts will not be compatible.
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be
// invalid.
if (!valid_bitcast_callback_(input_shape, lhs->shape()) ||
!valid_bitcast_callback_(new_filter_shape, rhs->shape()) ||
!valid_bitcast_callback_(convolution_shape, dot_output_shape)) {
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
return Status::OK();
}

View File

@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
bool rewrite_grad_op, bool use_fusion);
bool rewrite_inference_op, bool rewrite_grad_op,
bool use_fusion);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
private:
explicit BatchNormRewriterVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_;
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
bool BatchNormRewriterVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion) {
BatchNormRewriterVisitor visitor(computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_grad_op=*/rewrite_grad_op,
/*use_fusion=*/use_fusion);
BatchNormRewriterVisitor visitor(
computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_inference_op=*/rewrite_inference_op,
/*rewrite_grad_op=*/rewrite_grad_op,
/*use_fusion=*/use_fusion);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
return Status::OK();
}
Status BatchNormRewriterVisitor::HandleBatchNormInference(
HloInstruction* batch_norm) {
if (!rewrite_inference_op_) {
return Status::OK();
}
// Expand batch norm inference into smaller HLO ops.
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
int64 feature_index = batch_norm->feature_index();
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
HloInstruction* mean = batch_norm->mutable_operand(3);
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
if (i != feature_index) {
dimensions_without_feature.push_back(i);
}
}
auto scale_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
auto offset_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
auto mean_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
auto var_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
auto var_add_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
// X - E[X].
auto operand_minus_mean =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon].
auto normalized = computation_->AddInstruction(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
auto scaled_normalized =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
return Status::OK();
}
Status BatchNormRewriterVisitor::HandleBatchNormGrad(
HloInstruction* batch_norm) {
// Use the following formulas to calculate gradients:
@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
}
for (auto& comp : computations) {
if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
rewrite_grad_op_, use_fusion_)) {
rewrite_inference_op_, rewrite_grad_op_,
use_fusion_)) {
changed = true;
}
}
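
To make the expansion above concrete, here is a minimal NumPy sketch (illustration only, hypothetical helper name) showing that the broadcast-plus-elementwise sequence used by the rewriter, add epsilon, raise to the power -0.5, subtract the mean, scale, then shift, reproduces the usual closed-form batch-norm inference expression.

import numpy as np

def expanded_batch_norm_inference(x, scale, offset, mean, var, epsilon, feature_index=1):
    # Broadcast the per-feature parameters along every non-feature dimension,
    # matching HloInstruction::CreateBroadcast(..., {feature_index}) above.
    bshape = [1] * x.ndim
    bshape[feature_index] = -1
    scale_b, offset_b = scale.reshape(bshape), offset.reshape(bshape)
    mean_b, var_b = mean.reshape(bshape), var.reshape(bshape)

    rsqrt_var_add_epsilon = np.power(var_b + epsilon, -0.5)  # 1 / Sqrt[Var[X] + eps]
    normalized = (x - mean_b) * rsqrt_var_add_epsilon        # (X - E[X]) / Sqrt[...]
    return normalized * scale_b + offset_b                   # scale and shift

x = np.random.rand(4, 3, 5).astype(np.float32)
scale, offset = np.random.rand(3), np.random.rand(3)
mean, var = np.random.rand(3), np.random.rand(3) + 0.1
expected = (scale.reshape(1, 3, 1) * (x - mean.reshape(1, 3, 1))
            / np.sqrt(var.reshape(1, 3, 1) + 1e-3) + offset.reshape(1, 3, 1))
assert np.allclose(expanded_batch_norm_inference(x, scale, offset, mean, var, 1e-3),
                   expected)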

View File

@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormRewriter(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
bool rewrite_grad_op = false, bool use_fusion = true)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
~BatchNormRewriter() = default;
@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {
private:
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
};

View File

@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();

View File

@ -1156,8 +1156,7 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets(
// predecessor set, and must be unambiguous.
const PointsToSet& init_points_to =
points_to_analysis.GetPointsToSet(instruction->operand(0));
const std::vector<const LogicalBuffer*>& init_buffers =
init_points_to.element(buffer->index());
const auto& init_buffers = init_points_to.element(buffer->index());
CHECK_EQ(init_buffers.size(), 1);
CHECK_GT(predecessor_set.count(init_buffers[0]), 0);
predecessor_while_buffers.push_back(init_buffers[0]);
@ -1220,8 +1219,8 @@ const LogicalBuffer* AddBufferToColocatedSet(
std::vector<const LogicalBuffer*>* colocated_set) {
// CopyInsertion ensures root points-to set is unambiguous and distinct.
const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
CHECK(!points_to.IsAmbiguous());
CHECK(points_to.IsDistinct());
DCHECK(!points_to.IsAmbiguous());
DCHECK(points_to.IsDistinct());
colocated_set->push_back(points_to.element(index)[0]);
return colocated_set->back();
}

View File

@ -306,7 +306,7 @@ class BufferAssignment {
// Returns the set of LogicalBuffers which may be the source of the value at the
// given index and instruction.
const std::vector<const LogicalBuffer*>& GetSourceBuffers(
const PointsToSet::BufferList& GetSourceBuffers(
const HloInstruction* instruction, const ShapeIndex& index) const {
return GetPointsToSet(instruction).element(index);
}

View File

@ -54,8 +54,7 @@ class BufferLiveness {
bool MaybeLiveOut(const LogicalBuffer& buffer) const;
// Returns the complete set of buffers that may be live out of the module.
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& maybe_live_out_buffers()
const {
const PointsToSet::BufferSet& maybe_live_out_buffers() const {
return maybe_live_out_buffers_;
}
@ -106,7 +105,7 @@ class BufferLiveness {
tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
// LogicalBuffers that may be live out of the entry computation.
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers_;
PointsToSet::BufferSet maybe_live_out_buffers_;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
};

View File

@ -37,10 +37,9 @@ class BufferLivenessTest : public HloTestBase {
const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
const HloInstruction* instruction,
const ShapeIndex& index) {
const std::vector<const LogicalBuffer*>& pointed_to =
liveness.points_to_analysis()
.GetPointsToSet(instruction)
.element(index);
const auto& pointed_to = liveness.points_to_analysis()
.GetPointsToSet(instruction)
.element(index);
CHECK_EQ(1, pointed_to.size());
CHECK_EQ(instruction, pointed_to[0]->instruction());
CHECK(index == pointed_to[0]->index());
@ -72,9 +71,9 @@ class BufferLivenessTest : public HloTestBase {
ShapeUtil::GetSubshape(b->shape(), index)));
// Lookup PointsTo set for instructions 'a' and 'b'.
auto& points_to_analysis = liveness.points_to_analysis();
const std::vector<const LogicalBuffer*>& points_to_a =
const auto& points_to_a =
points_to_analysis.GetPointsToSet(a).element(index);
const std::vector<const LogicalBuffer*>& points_to_b =
const auto& points_to_b =
points_to_analysis.GetPointsToSet(b).element(index);
// Make sure PointsTo sets for 'a' and 'b' are unambiguous.
EXPECT_EQ(1, points_to_a.size());
@ -435,8 +434,9 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
0,
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
"param0"));
// Create independent computations for each tuple element.
@ -498,8 +498,9 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
0,
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
"param0"));
// Create dependent computations for each tuple element.

View File

@ -187,27 +187,25 @@ Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant(
// Multiple buffers within a parameter/constant may be live out, so collect
// a set of indices at which to copy first.
points_to.ForEachElement(
[this](const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& buffers) {
if (IsReadOnlyIndex(index)) {
return;
}
for (const LogicalBuffer* buffer : buffers) {
// pointee is the HloInstruction producing the buffer which may be
// liveout.
HloInstruction* pointee = buffer->instruction();
if (pointee->opcode() == HloOpcode::kParameter ||
pointee->opcode() == HloOpcode::kConstant) {
VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
<< " index: " << tensorflow::str_util::Join(index, ",")
<< " may be live out of computation: "
<< pointee->ToString();
RecordIndex(index);
break;
}
}
});
points_to.ForEachElement([this](const ShapeIndex& index,
const PointsToSet::BufferList& buffers) {
if (IsReadOnlyIndex(index)) {
return;
}
for (const LogicalBuffer* buffer : buffers) {
// pointee is the HloInstruction producing the buffer which may be
// liveout.
HloInstruction* pointee = buffer->instruction();
if (pointee->opcode() == HloOpcode::kParameter ||
pointee->opcode() == HloOpcode::kConstant) {
VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
<< " index: " << tensorflow::str_util::Join(index, ",")
<< " may be live out of computation: " << pointee->ToString();
RecordIndex(index);
break;
}
}
});
return Status::OK();
}
@ -230,8 +228,7 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
buffer_to_source_indices;
points_to.ForEachElement(
[this, &buffer_to_source_indices](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& buffers) {
const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
if (buffers.size() > 1) {
// Record ambiguous points-to set at 'index'.
if (!indices_to_copy_.element(index)) {
@ -285,7 +282,7 @@ Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
}
const auto& points_to_analysis = liveness.points_to_analysis();
// Lookup buffers for 'instruction_' and 'other_instruction'.
const std::vector<const LogicalBuffer*> instruction_buffers =
const auto instruction_buffers =
points_to_analysis.GetPointsToSet(instruction_).element(index);
// If 'instruction_' has ambiguous points-to-set at 'index', it would
// have been recorded in a previous pass (and we would have returned
@ -294,7 +291,7 @@ Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
CHECK_EQ(1, instruction_buffers.size());
const LogicalBuffer* instruction_buffer = instruction_buffers[0];
const std::vector<const LogicalBuffer*> other_instruction_buffers =
const auto other_instruction_buffers =
points_to_analysis.GetPointsToSet(other_instruction).element(index);
// Do not insert a copy if both instructions point at the same buffer.
// This eliminates unnecessary copies of read-only tuple elements.
@ -451,58 +448,57 @@ StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants(
FlatSet<const LogicalBuffer*> buffer_set;
ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
points_to.ForEachElement(
[init_hlo, read_only_indices, shared_copies, &buffer_set,
&copy_overrides](const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& buffers) {
// Look for read-only entry parameters.
if (!read_only_indices->element(index)) {
return;
points_to.ForEachElement([init_hlo, read_only_indices, shared_copies,
&buffer_set, &copy_overrides](
const ShapeIndex& index,
const PointsToSet::BufferList& buffers) {
// Look for read-only entry parameters.
if (!read_only_indices->element(index)) {
return;
}
for (const LogicalBuffer* buffer : buffers) {
HloInstruction* pointee = buffer->instruction();
const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
if (!is_constant) {
continue;
}
// We have found a constant that is read-only in
// the while body. These buffers are managed by the caller, and cannot
// be aliased with HLO buffers. Revert this read-only index,
// to allow it to be copied.
*read_only_indices->mutable_element(index) = false;
// Optimization to allow multiple while loops that share the same
// read-only entry constants to share a single copy.
// Only unambiguous and distinct array-shaped buffers are allowed, to
// reduce code complexity. The shape of the entry parameter must be
// identical to the shape of the init_hlo at this index, to ensure
// there were no intervening bitcast or GTE instructions, which are
// also hard to handle.
const Shape& pointee_shape = pointee->shape();
const Shape& init_shape =
ShapeUtil::GetSubshape(init_hlo->shape(), index);
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
ShapeUtil::Equal(pointee_shape, init_shape) &&
buffer_set.count(buffer) < 1) {
HloInstruction** copy = &(*shared_copies)[pointee];
if (*copy == nullptr) {
*copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
pointee_shape, HloOpcode::kCopy, pointee));
}
for (const LogicalBuffer* buffer : buffers) {
HloInstruction* pointee = buffer->instruction();
const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
if (!is_constant) {
continue;
}
// Add the copy as an override.
*copy_overrides.mutable_element(index) = *copy;
}
// We have found a constant that is read-only in
// the while body. These buffers are managed by the caller, and cannot
// be aliased with HLO buffers. Revert this read-only index,
// to allow it to be copied.
*read_only_indices->mutable_element(index) = false;
// Tracks whether this current buffer is distinct.
buffer_set.insert(buffer);
// Optimization to allow multiple while loops that share the same
// read-only entry constants to share a single copy.
// Only unambiguous and distinct array-shaped buffers are allowed, to
// reduce code complexity. The shape of the entry parameter must be
// identical to the shape of the init_hlo at this index, to ensure
// there were no intervening bitcast or GTE instructions, which are
// also hard to handle.
const Shape& pointee_shape = pointee->shape();
const Shape& init_shape =
ShapeUtil::GetSubshape(init_hlo->shape(), index);
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
ShapeUtil::Equal(pointee_shape, init_shape) &&
buffer_set.count(buffer) < 1) {
HloInstruction** copy = &(*shared_copies)[pointee];
if (*copy == nullptr) {
*copy =
pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
pointee_shape, HloOpcode::kCopy, pointee));
}
// Add the copy as an override.
*copy_overrides.mutable_element(index) = *copy;
}
// Tracks whether this current buffer is distinct.
buffer_set.insert(buffer);
// We've already reverted the read-only index and handled the
// single-copy optimization above, so there's nothing more to do.
break;
}
});
// We've already reverted the read-only index and handled the
// single-copy optimization above, so there's nothing more to do.
break;
}
});
return copy_overrides;
}

View File

@ -53,7 +53,7 @@ class CopyInsertionTest : public HloTestBase {
EXPECT_TRUE(points_to.IsDistinct());
EXPECT_TRUE(!points_to.IsAmbiguous());
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
auto maybe_live_out_buffers =
points_to_analysis
->GetPointsToSet(module->entry_computation()->root_instruction())
.CreateFlattenedSet();

View File

@ -102,6 +102,7 @@ cc_library(
":compiler_functor",
":cpu_runtime",
":cpu_runtime_avx",
":cpu_runtime_neon",
":cpu_runtime_sse4_1",
":disassembler",
":runtime_conv2d",
@ -284,6 +285,7 @@ cc_library(
deps = [
":cpu_runtime",
":cpu_runtime_avx",
":cpu_runtime_neon",
":cpu_runtime_sse4_1",
":disassembler",
":llvm_ir_runtime",
@ -309,11 +311,11 @@ cc_library(
srcs = ["cpu_runtime_sse4_1.cc"],
hdrs = ["cpu_runtime_sse4_1.h"],
copts = ["-DEIGEN_AVOID_STL_ARRAY"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = True,
)
cc_library(
@ -321,11 +323,24 @@ cc_library(
srcs = ["cpu_runtime_avx.cc"],
hdrs = ["cpu_runtime_avx.h"],
copts = ["-DEIGEN_AVOID_STL_ARRAY"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
)
cc_library(
name = "cpu_runtime_neon",
srcs = ["cpu_runtime_neon.cc"],
hdrs = ["cpu_runtime_neon.h"],
# runtime_copts() enables -mfpu=neon
copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = True,
)
cc_library(

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@ -54,6 +55,7 @@ CompilerFunctor::AllIntrinsics() {
VectorIntrinsics intrinsics;
intrinsics.sse_intrinsics = true;
intrinsics.avx_intrinsics = true;
intrinsics.neon_intrinsics = true;
return intrinsics;
}
@ -150,20 +152,28 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
std::vector<llvm::VecDesc> vector_functions;
const llvm::VecDesc four_wide_vector_functions[] = {
{"expf", runtime::kExpV4F32SymbolName, 4},
{"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4},
const llvm::VecDesc four_wide_vector_functions_neon[] = {
{"expf", runtime::kExpV4F32NEONSymbolName, 4},
{"llvm.exp.f32", runtime::kExpV4F32NEONSymbolName, 4},
{"logf", runtime::kLogV4F32SymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
{"logf", runtime::kLogV4F32NEONSymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4},
};
const llvm::VecDesc eight_wide_vector_functions[] = {
{"expf", runtime::kExpV8F32SymbolName, 8},
{"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
const llvm::VecDesc four_wide_vector_functions_sse[] = {
{"expf", runtime::kExpV4F32SSESymbolName, 4},
{"llvm.exp.f32", runtime::kExpV4F32SSESymbolName, 4},
{"logf", runtime::kLogV8F32SymbolName, 8},
{"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
{"logf", runtime::kLogV4F32SSESymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4},
};
const llvm::VecDesc eight_wide_vector_functions_avx[] = {
{"expf", runtime::kExpV8F32AVXSymbolName, 8},
{"llvm.exp.f32", runtime::kExpV8F32AVXSymbolName, 8},
{"logf", runtime::kLogV8F32AVXSymbolName, 8},
{"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8},
};
// These functions are generated by XLA as LLVM IR, so they're always
@ -176,27 +186,45 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
{"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8},
};
if (arch == llvm::Triple::x86 || llvm::Triple::x86_64) {
llvm::SmallVector<llvm::StringRef, 32> features;
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
if (std::find(features.begin(), features.end(), "+sse4.1") !=
features.end() &&
available_intrinsics.sse_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(four_wide_vector_functions),
std::end(four_wide_vector_functions));
llvm::SmallVector<llvm::StringRef, 32> features;
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
auto has_feature = [&features](const llvm::StringRef feature) {
return std::find(features.begin(), features.end(), feature) !=
features.end();
};
switch (arch) {
case llvm::Triple::x86:
case llvm::Triple::x86_64: {
if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(four_wide_vector_functions_sse),
std::end(four_wide_vector_functions_sse));
}
if (has_feature("+avx") && available_intrinsics.avx_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(eight_wide_vector_functions_avx),
std::end(eight_wide_vector_functions_avx));
}
break;
}
if (std::find(features.begin(), features.end(), "+avx") != features.end() &&
available_intrinsics.avx_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(eight_wide_vector_functions),
std::end(eight_wide_vector_functions));
case llvm::Triple::arm:
case llvm::Triple::aarch64: {
if (has_feature("+neon") && available_intrinsics.neon_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(four_wide_vector_functions_neon),
std::end(four_wide_vector_functions_neon));
}
break;
}
default:
break;
}
vector_functions.insert(vector_functions.end(),
std::begin(ir_vector_functions),
std::end(ir_vector_functions));
return vector_functions;
}
} // namespace
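
A small Python sketch (hypothetical, deliberately simplified) of the selection logic above: the target architecture and its feature string decide which vectorized expf/logf runtime symbols are offered to the LLVM vectorizer, gated by which runtime translation units are available.

def select_vector_functions(arch, feature_string, available):
    # 'available' mirrors VectorIntrinsics: which runtime libraries were built in.
    features = set(f for f in feature_string.split(",") if f)
    mappings = []
    if arch in ("x86", "x86_64"):
        if "+sse4.1" in features and available.get("sse"):
            mappings += [("expf", "__xla_cpu_runtime_ExpV4F32SSE", 4),
                         ("logf", "__xla_cpu_runtime_LogV4F32SSE", 4)]
        if "+avx" in features and available.get("avx"):
            mappings += [("expf", "__xla_cpu_runtime_ExpV8F32AVX", 8),
                         ("logf", "__xla_cpu_runtime_LogV8F32AVX", 8)]
    elif arch in ("arm", "aarch64"):
        if "+neon" in features and available.get("neon"):
            mappings += [("expf", "__xla_cpu_runtime_ExpV4F32NEON", 4),
                         ("logf", "__xla_cpu_runtime_LogV4F32NEON", 4)]
    return mappings

# Example: an AVX-capable x86_64 target gets both the 4-wide SSE and the
# 8-wide AVX mappings, since its feature string reports both features.
print(select_vector_functions("x86_64", "+sse4.1,+avx",
                              {"sse": True, "avx": True, "neon": True}))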

View File

@ -35,6 +35,7 @@ class CompilerFunctor {
struct VectorIntrinsics {
bool sse_intrinsics;
bool avx_intrinsics;
bool neon_intrinsics;
};
// Returns a VectorIntrinsics where all intrinsics are available.

View File

@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(

View File

@ -316,7 +316,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> CpuExecutable::ExecuteOnStream(
[&buffers, &buffers_in_result, &result_buffer, this](
const ShapeIndex& index, size_t* buffer_entry) {
if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) {
const std::vector<const LogicalBuffer*>& sources =
const auto& sources =
this->GetRootPointsToSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton.

View File

@ -20,13 +20,13 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#ifdef __AVX__
xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32(
xla::cpu::runtime::V8F32 x) {
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
xla::cpu::runtime::V8F32AVX x) {
return Eigen::internal::pexp(x);
}
xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32(
xla::cpu::runtime::V8F32 x) {
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
xla::cpu::runtime::V8F32AVX x) {
return Eigen::internal::plog(x);
}
#endif // __AVX__
@ -35,8 +35,8 @@ namespace xla {
namespace cpu {
namespace runtime {
const char *const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
const char *const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32";
const char *const kExpV8F32AVXSymbolName = "__xla_cpu_runtime_ExpV8F32AVX";
const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX";
} // namespace runtime
} // namespace cpu

View File

@ -28,11 +28,10 @@ namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kExpV8F32SymbolName;
extern const char *const kLogV8F32SymbolName;
extern const char *const kTanhV8F32SymbolName;
extern const char *const kExpV8F32AVXSymbolName;
extern const char *const kLogV8F32AVXSymbolName;
typedef float V8F32 __attribute__((__vector_size__(32)));
typedef float V8F32AVX __attribute__((__vector_size__(32)));
} // namespace runtime
} // namespace cpu
} // namespace xla
@ -42,12 +41,11 @@ extern "C" {
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32(xla::cpu::runtime::V8F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32(xla::cpu::runtime::V8F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK;
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/Eigen/Core"
#ifdef __ARM_NEON__
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
xla::cpu::runtime::V4F32NEON x) {
return Eigen::internal::pexp(x);
}
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
xla::cpu::runtime::V4F32NEON x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::plog(p);
}
#endif // __ARM_NEON__
namespace xla {
namespace cpu {
namespace runtime {
const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON";
const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON";
} // namespace runtime
} // namespace cpu
} // namespace xla

View File

@ -0,0 +1,62 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
// This header declares functions which may be called by the generated code on
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
// xla::cpu::SimpleResolver.
#include "tensorflow/core/platform/macros.h"
#ifdef __ARM_NEON__
// For the other runtimes (AVX, SSE4.1) we define the vector type directly using
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kExpV4F32NEONSymbolName;
extern const char *const kLogV4F32NEONSymbolName;
#ifdef __ARM_NEON__
typedef float32x4_t V4F32NEON;
#else
// On non-ARM platforms ensure the declaration is present
struct V4F32NEON;
#endif // __ARM_NEON__
} // namespace runtime
} // namespace cpu
} // namespace xla
extern "C" {
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK;
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_

View File

@ -21,14 +21,14 @@ limitations under the License.
#ifdef __SSE4_1__
xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32(
xla::cpu::runtime::V4F32 x) {
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
xla::cpu::runtime::V4F32SSE x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::pexp(p);
}
xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32(
xla::cpu::runtime::V4F32 x) {
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::plog(p);
}
@ -39,8 +39,8 @@ namespace xla {
namespace cpu {
namespace runtime {
const char *const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char *const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32";
const char *const kExpV4F32SSESymbolName = "__xla_cpu_runtime_ExpV4F32SSE";
const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE";
} // namespace runtime
} // namespace cpu

View File

@ -28,11 +28,10 @@ namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kExpV4F32SymbolName;
extern const char *const kLogV4F32SymbolName;
extern const char *const kTanhV4F32SymbolName;
extern const char *const kExpV4F32SSESymbolName;
extern const char *const kLogV4F32SSESymbolName;
typedef float V4F32 __attribute__((__vector_size__(16)));
typedef float V4F32SSE __attribute__((__vector_size__(16)));
} // namespace runtime
} // namespace cpu
@ -43,11 +42,11 @@ extern "C" {
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32(xla::cpu::runtime::V4F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32(xla::cpu::runtime::V4F32 x)
TF_ATTRIBUTE_WEAK;
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK;
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_

View File

@ -566,7 +566,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
[&buffers, &buffers_in_result, &result_buffer, this](
const ShapeIndex& index, size_t* buffer_entry) {
if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) {
const std::vector<const LogicalBuffer*>& sources =
const auto& sources =
this->GetRootPointsToSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton.

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
@ -91,10 +92,12 @@ class JITSymbolTable {
ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue);
ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation);
ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation);
ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32);
ADD_JIT_SYMBOL_TO_TABLE(LogV8F32);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32);
ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX);
ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE);
ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON);
ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON);
ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64);
@ -162,8 +165,9 @@ llvm::StringRef GetHostCpuName() {
CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
CompilerFunctor::VectorIntrinsics intrinsics;
intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32 != nullptr);
intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32 != nullptr);
intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32SSE != nullptr);
intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32AVX != nullptr);
intrinsics.neon_intrinsics = (&__xla_cpu_runtime_ExpV4F32NEON != nullptr);
return intrinsics;
}
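The availability probes above work because the runtime entry points are declared TF_ATTRIBUTE_WEAK: when no definition exists for the target architecture the address compares equal to null and the corresponding flag stays false, so only symbols that actually resolved get registered for the JIT. A rough sketch of such a name-to-address table is below; SimpleSymbolTable is illustrative, not the actual JITSymbolTable API.

#include <cstdint>
#include <string>
#include <unordered_map>

// Maps canonical runtime symbol names to addresses so a JIT resolver can
// look them up by name when linking generated code.
class SimpleSymbolTable {
 public:
  void Add(const std::string& name, uint64_t address) { table_[name] = address; }
  uint64_t Lookup(const std::string& name) const {
    auto it = table_.find(name);
    return it == table_.end() ? 0 : it->second;
  }

 private:
  std::unordered_map<std::string, uint64_t> table_;
};

// Usage sketch: register only the intrinsics whose weak symbols resolved, e.g.
//   if (&__xla_cpu_runtime_ExpV4F32SSE != nullptr)
//     table.Add(xla::cpu::runtime::kExpV4F32SSESymbolName,
//               reinterpret_cast<uint64_t>(&__xla_cpu_runtime_ExpV4F32SSE));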

View File

@ -228,6 +228,9 @@ class DfsHloVisitor {
virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;
virtual Status HandleBatchNormInference(
HloInstruction* batchNormInference) = 0;
virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;
// Invoked to inform the visitor that the traversal has completed, and that
@ -245,6 +248,11 @@ class DfsHloVisitor {
VisitState GetVisitState(int id) { return visit_state_.GetState(id); }
VisitState GetVisitState(const HloInstruction& instruction);
// Resize internal state if necessary to hold state for ids <= num.
// This call is purely a performance hint and can be omitted without
// affecting correctness.
void ReserveVisitStates(int num) { visit_state_.Reserve(num); }
void SetVisitState(int id, VisitState state) {
visit_state_.SetState(id, state);
}
@ -298,35 +306,35 @@ class DfsHloVisitor {
private:
class DFSVisitStates {
public:
DFSVisitStates() {
// Avoid frequent resizes of the visited bits array
states_.reserve(512);
DFSVisitStates() {}
void Reserve(uint64 num) {
states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord);
}
VisitState GetState(int id) {
int word_index = id / kStatesPerWord;
VisitState GetState(uint64 id) {
uint64 word_index = id / kStatesPerWord;
if (word_index >= states_.size()) {
return VisitState::kNotVisited;
}
static_assert(static_cast<int>(VisitState::kVisited) < 3,
"VisitState must fit in two bits");
uint64 w = states_[word_index];
int shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state
return static_cast<VisitState>((w >> shift) & 0x3);
}
void SetState(int id, VisitState state) {
int word_index = id / kStatesPerWord;
void SetState(uint64 id, VisitState state) {
uint64 word_index = id / kStatesPerWord;
if (word_index >= states_.size()) {
states_.resize(word_index + 1, 0);
}
uint64* w = &states_[word_index];
int shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state
uint64 mask = 0x3ull << shift;
*w = (*w & ~mask) | (static_cast<uint64>(state) << shift);
DCHECK_EQ(GetState(id), state);
}
private:
static const int kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
// Map from id to two-bit states. We store 32 such states per 64-bit
// value
std::vector<uint64> states_;
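The container above keeps each instruction's visit state in two bits of a lazily grown vector of 64-bit words; Reserve is the performance hint mentioned in the earlier hunk, and reading an id past the end simply reports kNotVisited. A self-contained sketch of the same packing (illustrative names, 32 two-bit states per word):

#include <cstdint>
#include <vector>

enum class VisitState : uint8_t { kNotVisited = 0, kVisiting = 1, kVisited = 2 };

class PackedVisitStates {
 public:
  VisitState Get(uint64_t id) const {
    uint64_t word = id / kStatesPerWord;
    if (word >= words_.size()) return VisitState::kNotVisited;
    uint32_t shift = 2 * (id % kStatesPerWord);  // 2 bits per state
    return static_cast<VisitState>((words_[word] >> shift) & 0x3);
  }

  void Set(uint64_t id, VisitState state) {
    uint64_t word = id / kStatesPerWord;
    if (word >= words_.size()) words_.resize(word + 1, 0);
    uint32_t shift = 2 * (id % kStatesPerWord);
    uint64_t mask = 0x3ull << shift;
    words_[word] = (words_[word] & ~mask) | (static_cast<uint64_t>(state) << shift);
  }

 private:
  static constexpr uint32_t kStatesPerWord = 32;  // 64 bits / 2 bits per state
  std::vector<uint64_t> words_;
};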

View File

@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
return DefaultAction(hlo);
}
Status HandleBatchNormInference(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
Status HandleBatchNormGrad(HloInstruction* hlo) override {
return DefaultAction(hlo);
}

View File

@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
// instead.
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(

View File

@ -232,7 +232,7 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteOnStream(
TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus(
[&referred_by_output, &buffer_allocations, this](
const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& buffers) {
const PointsToSet::BufferList& buffers) {
// The points to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction produced
// the array at this element.
@ -311,7 +311,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> GpuExecutable::ExecuteOnStream(
[&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
const ShapeIndex& index, size_t* buffer_entry) {
if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
const std::vector<const LogicalBuffer*>& sources =
const auto& sources =
this->GetRootPointsToSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction

View File

@ -39,7 +39,7 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
for (const HloInstruction* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
const PointsToSet::BufferList& points_to) {
buffers.insert(buffers.end(), points_to.begin(), points_to.end());
});
}
@ -107,7 +107,7 @@ Status HeapSimulator::RunComputation(
FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
const HloInstruction* root = computation.root_instruction();
FlatSet<const LogicalBuffer*> output_source_buffers =
auto output_source_buffers =
points_to_analysis.GetPointsToSet(root).CreateFlattenedSet();
std::vector<const LogicalBuffer*> dead_buffers_to_free;

View File

@ -481,7 +481,7 @@ TEST_F(HeapSimulatorTest, WholeModule) {
// Base class for heap algorithm tests.
class HeapAlgorithmTestBase : public ::testing::Test {
protected:
HeapAlgorithmTestBase() {
HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
buffer_a_ = DummyLogicalBuffer();
buffer_b_ = DummyLogicalBuffer();
buffer_c_ = DummyLogicalBuffer();
@ -505,15 +505,16 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const LogicalBuffer* buffer_i_;
private:
// Create a dummy LogicalBuffer to pass to the heap algorithm. Since the
// algorithms only use the buffer as a handle, we don't need to fill in much
// other than the id and color.
// Create a dummy LogicalBuffer to pass to the heap algorithm.
const LogicalBuffer* DummyLogicalBuffer() {
const LogicalBuffer::Id id = buffers_.size();
buffers_.emplace_back(MakeUnique<LogicalBuffer>(nullptr, ShapeIndex{}, id));
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
buffers_.emplace_back(MakeUnique<LogicalBuffer>(const0, ShapeIndex{}, id));
return buffers_.back().get();
}
HloComputation::Builder builder_;
std::vector<std::unique_ptr<LogicalBuffer>> buffers_;
};

View File

@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
return Status::OK();
}
Status HloCostAnalysis::HandleBatchNormInference(
HloInstruction* batchNormInference) {
// TODO(b/62294698): Implement cost analysis for batch-norm-inference.
return Status::OK();
}
Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
// TODO(b/62294698): Implement cost analysis for batch-norm-grad.
return Status::OK();

View File

@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation* function_handle) override;
Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;

View File

@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kParameter:
return kOrange;
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:

View File

@ -406,6 +406,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(scale);
instruction->AppendOperand(offset);
instruction->AppendOperand(mean);
instruction->AppendOperand(variance);
instruction->epsilon_ = epsilon;
instruction->feature_index_ = feature_index;
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* scale, HloInstruction* mean,
@ -726,7 +743,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// If this is already a multioutput fusion instruction, expand the root
// tuple by 1.
HloInstruction* fused_root = fused_expression_root();
std::vector<HloInstruction*> tuple_elements;
HloInstruction::InstructionVector tuple_elements;
bool newly_created_tuple_instr = false;
if (fused_root->opcode() == HloOpcode::kTuple) {
tuple_elements = fused_root->operands();
@ -735,10 +752,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
newly_created_tuple_instr = true;
}
if (clone->opcode() == HloOpcode::kTuple) {
const auto& tuple_elements_to_fuse = clone->operands();
tuple_elements.insert(tuple_elements.end(),
tuple_elements_to_fuse.begin(),
tuple_elements_to_fuse.end());
for (auto inst : clone->operands()) {
tuple_elements.push_back(inst);
}
} else {
tuple_elements.push_back(clone);
}
@ -1065,6 +1081,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
new_operands[2], epsilon(),
feature_index());
case HloOpcode::kBatchNormInference:
CHECK_EQ(new_operands.size(), 5);
return CreateBatchNormInference(
shape, new_operands[0], new_operands[1], new_operands[2],
new_operands[3], new_operands[4], epsilon(), feature_index());
case HloOpcode::kInfeed:
CHECK_EQ(new_operands.size(), 0);
return CreateInfeed(shape, infeed_config());
@ -1355,6 +1377,7 @@ bool HloInstruction::IdenticalSlowPath(
ShapeUtil::Compatible(shape(), other.shape());
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
return feature_index() == other.feature_index() &&
epsilon() == other.epsilon();
@ -1940,8 +1963,8 @@ HloInstruction::fused_instructions() const {
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
shape_(shape),
opcode_(opcode),
shape_(shape),
name_("%" + HloOpcodeString(opcode)) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
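The reordered initializer list above tracks the member reordering shown in the header hunk further below: non-static members are always initialized in declaration order, whatever order the constructor writes them in, so keeping the two in sync avoids -Wreorder warnings and accidental reads of uninitialized members. A tiny illustration with made-up names:

struct Example {
  // Declaration order is a_ then b_. Even though the initializer list names
  // b_ first, a_(b_) runs first and would read an uninitialized b_.
  Example() : b_(1), a_(b_) {}
  int a_;
  int b_;
};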
@ -1952,6 +1975,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAbs(this, operands_[0]);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
case HloOpcode::kBatchNormInference:
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kSign:
@ -2092,12 +2117,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
HloOpcodeString(opcode_).c_str());
}
using DFSStack =
tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
inline bool PushDFSChild(
DfsHloVisitor* visitor,
std::vector<std::pair<int, HloInstruction*>>* dfs_stack,
HloInstruction* child) {
inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
@ -2120,13 +2146,15 @@ using InternalCompareFunction =
static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
// We need to keep track of both the id and the instruction because
// instructions can get deleted while they are on the stack, so we
// can't always use the (potentially dead) instruction object to grab
// its id.
std::vector<std::pair<int, HloInstruction*>> dfs_stack;
DFSStack dfs_stack;
dfs_stack.emplace_back(root->unique_id(), root);
do {

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -224,6 +225,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, float epsilon, int64 feature_index);
// Creates a batch-norm-inference instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index);
// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
@ -330,7 +337,8 @@ class HloInstruction {
int64 operand_count() const { return operands_.size(); }
// Returns the vector of operands of this instruction.
const std::vector<HloInstruction*>& operands() const { return operands_; }
using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
const InstructionVector& operands() const { return operands_; }
// Returns the index of 'target' in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
@ -980,15 +988,34 @@ class HloInstruction {
int unique_id_; // Unique to this HloInstruction within a HloModule
// Opcode for this instruction.
HloOpcode opcode_;
// Instruction operands.
InstructionVector operands_;
// The set of control predecessors of this instruction.
std::vector<HloInstruction*> control_predecessors_;
// The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical
// members. The set enables fast membership testing and the vector enables
// fast, stable iteration.
std::vector<HloInstruction*> users_;
std::unordered_set<const HloInstruction*> user_set_;
// The set of control successors of this instruction.
std::vector<HloInstruction*> control_successors_;
// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;
// Shape of outfeed request.
Shape outfeed_shape_;
// Result shape of this instruction.
Shape shape_;
// Opcode for this instruction.
HloOpcode opcode_;
// Literal, only present for kConstant.
std::unique_ptr<Literal> literal_;
@ -1054,22 +1081,6 @@ class HloInstruction {
// Outfeed configuration information, only present for kOutfeed.
string outfeed_config_;
// Instruction operands.
std::vector<HloInstruction*> operands_;
// The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical
// members. The set enables fast membership testing and the vector enables
// fast, stable iteration.
std::vector<HloInstruction*> users_;
std::unordered_set<const HloInstruction*> user_set_;
// The set of control predecessors of this instruction.
std::vector<HloInstruction*> control_predecessors_;
// The set of control successors of this instruction.
std::vector<HloInstruction*> control_successors_;
// A trace instruction that consumes this instruction.
//
// Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
@ -1098,9 +1109,6 @@ class HloInstruction {
// String identifier for instruction.
string name_;
// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;
// Metadata for debugging.
OpMetadata metadata_;

View File

@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "add";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
case HloOpcode::kBatchNormInference:
return "batch-norm-inference";
case HloOpcode::kBatchNormGrad:
return "batch-norm-grad";
case HloOpcode::kBitcast:

View File

@ -31,6 +31,7 @@ enum class HloOpcode {
kAbs,
kAdd,
kBatchNormTraining,
kBatchNormInference,
kBatchNormGrad,
kBitcast,
kBroadcast,

View File

@ -55,16 +55,6 @@ namespace {
// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
// Conservatively, don't rematerialize instruction with control
// dependencies. For one, control dependencies are added to prevent
// interference of aliased buffers (say, in while bodies) and
// rematerialization is ignorant of liveness and may break the intended
// ordering.
if (!instruction->control_predecessors().empty() ||
!instruction->control_successors().empty()) {
return false;
}
// Don't rematerialize instructions with side effects or instructions which
// cannot be cloned safely.
switch (instruction->opcode()) {
@ -503,7 +493,7 @@ MemoryUsageTracker::MemoryUsageTracker(
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list)
: computation_(computation), instruction_list_(instruction_list) {
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
@ -906,6 +896,19 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
continue;
}
// If any of the candidate's control successors has been placed, we need to
// skip this candidate; otherwise we would violate the control dependency.
bool control_successor_placed =
std::any_of(candidate->control_successors().begin(),
candidate->control_successors().end(),
[&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst);
});
if (control_successor_placed) {
continue;
}
const int64 memory_reduced =
memory_tracker.MemoryReducedIfRematerialized(item);
@ -1047,6 +1050,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
// Add control dependencies to the new operation.
for (auto successor : best->control_successors()) {
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
}
for (auto predecessor : best->control_predecessors()) {
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
}
Item* remat_item = instruction_list.CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
@ -1082,6 +1094,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
}
}
}
// Insert the rematerialized instruction before any of its control
// successors to preserve the ordering imposed by control dependencies.
for (auto successor : remat->control_successors()) {
Item* successor_item = instruction_list.GetItem(successor);
// Assert that we never rematerialize an operation with a control
// successor that has already been placed.
CHECK(!successor_item->placed);
place_before.push_back(successor_item);
}
instruction_list.InsertBeforeInstructions(remat_item, place_before);
// If the rematerialized instruction is dead then rematerialization is

View File

@ -78,6 +78,7 @@ namespace xla {
// Expensive instructions.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kCall:
case HloOpcode::kConvolution:

View File

@ -628,9 +628,8 @@ Status CheckLayouts(
const PointsToSet& points_to_set =
points_to_analysis->GetPointsToSet(instruction.get());
TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
[&instruction](
ShapeIndex index,
const std::vector<const LogicalBuffer*>& buffers) -> Status {
[&instruction](ShapeIndex index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
const Shape& instruction_subshape =
ShapeUtil::GetSubshape(instruction->shape(), index);
@ -934,7 +933,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
return points_to_set.ForEachElementWithStatus(
[this, &shape_layout, constraints](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& buffers) -> Status {
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
for (const LogicalBuffer* buffer : buffers) {
if (constraints->BufferLayout(*buffer) == nullptr &&
@ -1076,7 +1075,7 @@ StatusOr<Layout> InferArrayLayout(
TF_RET_CHECK(
!points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index));
const std::vector<const LogicalBuffer*>& source_buffers =
const auto& source_buffers =
points_to_analysis.GetPointsToSet(instruction).element(index);
TF_RET_CHECK(!source_buffers.empty());

View File

@ -80,7 +80,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
HloInstruction* instruction, const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis) {
std::vector<std::pair<HloInstruction*, int64>> uses;
const std::vector<const LogicalBuffer*>& points_to =
const PointsToSet::BufferList& points_to =
points_to_analysis.GetPointsToSet(instruction).element(index);
for (const LogicalBuffer* buffer : points_to) {
for (const BufferAlias& alias :

View File

@ -26,6 +26,14 @@ limitations under the License.
namespace xla {
LogicalBuffer::LogicalBuffer(HloInstruction* instruction,
const ShapeIndex& index, Id id)
: instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) {
const auto& s = shape();
is_array_ = ShapeUtil::IsArray(s);
is_tuple_ = ShapeUtil::IsTuple(s);
}
string LogicalBuffer::ToString() const {
return tensorflow::strings::StrCat(instruction_->name(), "[",
tensorflow::str_util::Join(index_, ","),

View File

@ -97,11 +97,7 @@ class LogicalBuffer {
using SizeFunction = std::function<int64(const LogicalBuffer&)>;
using AlignmentFunction = std::function<int64(LogicalBuffer::Color)>;
LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id)
: instruction_(instruction),
index_(index),
id_(id),
color_(kInvalidColor) {}
LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id);
Id id() const { return id_; }
@ -140,14 +136,14 @@ class LogicalBuffer {
bool IsTopLevel() const { return index_.empty(); }
// Whether this buffer contains a tuple.
bool IsTuple() const { return ShapeUtil::IsTuple(shape()); }
bool IsTuple() const { return is_tuple_; }
// Whether this buffer contains an array.
bool IsArray() const { return is_array_; }
// operator< is required for std::set.
bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; }
// Whether this buffer contains an array.
bool IsArray() const { return ShapeUtil::IsArray(shape()); }
string ToString() const;
LogicalBufferProto ToProto(const SizeFunction& size_fn) const;
@ -160,9 +156,11 @@ class LogicalBuffer {
private:
HloInstruction* instruction_;
ShapeIndex index_;
Id id_;
Id id_ : 62;
bool is_array_ : 1;
bool is_tuple_ : 1;
Color color_;
ShapeIndex index_;
// Similar to HLO constructs (HloInstruction, etc), pointers are used for
// comparison to equality, so disable all copying.
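Defining the constructor out of line lets LogicalBuffer cache the array/tuple predicates once, and the bit-fields pack the 62-bit id plus the two cached flags into a single 64-bit word so the object does not grow. A minimal sketch of the packing idea; the type is hypothetical, and the exact layout of mixed bit-fields is implementation-defined (mainstream compilers keep this in one 8-byte word):

#include <cstdint>

struct PackedBuffer {
  int64_t id : 62;    // plenty of room for buffer ids
  bool is_array : 1;  // cached "is the shape an array?"
  bool is_tuple : 1;  // cached "is the shape a tuple?"
};

// Usage sketch:
//   PackedBuffer b{/*id=*/42, /*is_array=*/true, /*is_tuple=*/false};
//   // b.id, b.is_array and b.is_tuple behave like ordinary members.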

View File

@ -129,7 +129,7 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
// metadata.
bool IsElementwiseOfEquivalentReshapesOrTransposes(
const HloInstruction* instruction) {
const std::vector<HloInstruction*>& operands = instruction->operands();
const auto& operands = instruction->operands();
HloInstruction* first_reshape_operand =
FirstNonScalarAndNonTrivialReshapeOperand(instruction);
// If there are no non-trivial reshapes or transposes, then there is nothing
@ -216,7 +216,7 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
<< "\n\tnew elementwise shape: "
<< ShapeUtil::HumanString(new_elementwise_shape);
std::vector<HloInstruction*> operands = instruction->operands();
auto operands = instruction->operands();
for (size_t i = 0; i < operands.size(); ++i) {
// All scalar operands remain as-is, even if they're reshape or transpose,
// to simplify handling wrt special scalar broadcast rules for ops like

View File

@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddBatchNormTrainingInstruction(
arg->batch_norm_training_request());
break;
case OpRequest::kBatchNormInferenceRequest:
handle_status = computation->AddBatchNormInferenceInstruction(
arg->batch_norm_inference_request());
break;
case OpRequest::kBatchNormGradRequest:
handle_status = computation->AddBatchNormGradInstruction(
arg->batch_norm_grad_request());

View File

@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
output_shape_for_mean_and_var});
}
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
offset_shape, "offset input of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
tensorflow::Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
"got feature_index %lld, and rank %lld",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to "
"be a non-negative number, got %lld",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-inference to be at least 1; got %lld",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(scale_shape));
}
if (!ShapeUtil::ElementIsFloating(operand_shape)) {
return InvalidArgument(
"The operand to batch-norm-inference must have a floating point "
"element type, but its element type is %s",
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of offset factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(offset_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of scale factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(scale_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of mean is %s "
"and the shape of operand is %s",
PrimitiveType_Name(mean_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of variance is %s "
"and the shape of operand is %s",
PrimitiveType_Name(variance_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
const int64 feature_count = operand_shape.dimensions(feature_index);
Shape output_shape_for_mean_and_var =
ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count, "
"but the size of offset factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count, "
"but the size of scale factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count, "
"but the size of mean is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count, "
"but the size of variance is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(variance_shape, 0), feature_count);
}
return operand_shape;
}
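Because every check above only demands matching element types and rank-1, feature-sized scale/offset/mean/variance inputs, a successful inference returns the operand shape unchanged. A hedged usage sketch (the shapes and the wrapper function are chosen purely for illustration; the include paths are the usual XLA ones):

#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

StatusOr<Shape> InferExampleBatchNormInferenceShape() {
  // Operand laid out as [batch, height, width, features]; the per-feature
  // inputs are rank 1 with length equal to dimension 3 of the operand.
  Shape operand = ShapeUtil::MakeShape(F32, {8, 16, 16, 32});
  Shape per_feature = ShapeUtil::MakeShape(F32, {32});
  // When the checks pass, the returned shape equals `operand`.
  return ShapeInference::InferBatchNormInferenceShape(
      operand, /*offset_shape=*/per_feature, /*scale_shape=*/per_feature,
      /*mean_shape=*/per_feature, /*variance_shape=*/per_feature,
      /*feature_index=*/3);
}

}  // namespace xla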
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,

View File

@ -71,6 +71,13 @@ class ShapeInference {
const Shape& scale_shape,
int64 feature_index);
// Infers the shape produced by InferBatchNormInference with the given
// operands.
static StatusOr<Shape> InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index);
// Infers the shape produced by InferBatchNormGrad with the given operands.
static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
const Shape& scale_shape,

View File

@ -46,8 +46,7 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
bool PointsToSet::IsAmbiguous() const {
bool ambiguous = false;
ForEachElement(
[&ambiguous](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
[&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
ambiguous |= points_to.size() > 1;
});
return ambiguous;
@ -56,9 +55,8 @@ bool PointsToSet::IsAmbiguous() const {
bool PointsToSet::IsDistinct() const {
bool distinct = true;
std::set<const LogicalBuffer*> all_points_to;
ForEachElement([&distinct, &all_points_to](
const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& points_to) {
ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/,
const BufferList& points_to) {
for (auto& buffer : points_to) {
if (all_points_to.count(buffer) != 0) {
distinct = false;
@ -75,34 +73,31 @@ size_t PointsToSet::size() const {
return CreateFlattenedSet().size();
}
tensorflow::gtl::FlatSet<const LogicalBuffer*> PointsToSet::CreateFlattenedSet()
const {
tensorflow::gtl::FlatSet<const LogicalBuffer*> flat_set;
ForEachElement([&flat_set](const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& buffers) {
flat_set.insert(buffers.begin(), buffers.end());
});
PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
BufferSet flat_set;
ForEachElement(
[&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
flat_set.insert(buffers.begin(), buffers.end());
});
return flat_set;
}
bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
bool found = false;
ForEachElement(
[&found, &buffer](
const ShapeIndex& /*index*/,
const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
if (!found &&
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end()) {
found = true;
}
});
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
const BufferList& pointed_to_buffers) {
if (!found &&
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end()) {
found = true;
}
});
return found;
}
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
const ShapeIndex& index) const {
const std::vector<const LogicalBuffer*>& pointed_to_buffers = element(index);
const auto& pointed_to_buffers = element(index);
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end();
}
@ -115,14 +110,14 @@ void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
mutable_element(index)->push_back(&buffer);
}
const std::set<HloInstruction*>& PointsToSet::tuple_sources(
const PointsToSet::SourceSet& PointsToSet::tuple_sources(
const ShapeIndex& index) const {
return tuple_sources_.element(index);
return tree_.element(index).tuple_sources;
}
void PointsToSet::add_tuple_source(const ShapeIndex& index,
HloInstruction* tuple) {
tuple_sources_.mutable_element(index)->insert(tuple);
tree_.mutable_element(index)->tuple_sources.insert(tuple);
}
/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
@ -177,7 +172,7 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(
points_to_set.ForEachElement(
[this, &instruction](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
const PointsToSet::BufferList& pointed_to_buffers) {
for (const LogicalBuffer* buffer : pointed_to_buffers) {
PerBuffer(buffer->id())
->buffer_aliases.emplace_back(instruction.get(), index);
@ -205,7 +200,7 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
points_to_set.ForEachMutableElement(
[this, hlo_instruction](const ShapeIndex& index,
std::vector<const LogicalBuffer*>* buffers) {
PointsToSet::BufferList* buffers) {
const LogicalBuffer& buffer = NewLogicalBuffer(hlo_instruction, index);
buffers->push_back(&buffer);
});
@ -232,7 +227,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
// operand to the points-to set for this GetTupleElement instruction.
points_to_set.ForEachMutableElement(
[&, this](const ShapeIndex& target_index,
std::vector<const LogicalBuffer*>* points_to) {
PointsToSet::BufferList* points_to) {
// Construct an index into the operand by prepending element_index to
// the index for the GetTupleElement instruction's points-to set.
ShapeIndex src_index;
@ -289,7 +284,7 @@ Status TuplePointsToAnalysis::HandleTuple(
operand_points_to_set.ForEachElement(
[&points_to_set, &operand_points_to_set, i](
const ShapeIndex& src_index,
const std::vector<const LogicalBuffer*>& points_to) {
const PointsToSet::BufferList& points_to) {
ShapeIndex target_index;
target_index.push_back(i);
for (auto element : src_index) {
@ -324,7 +319,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select,
PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true);
const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
points_to_set.ForEachMutableElement(
[&](const ShapeIndex& index, std::vector<const LogicalBuffer*>* buffers) {
[&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
for (const LogicalBuffer* false_buffer :
false_points_to_set.element(index)) {
points_to_set.AddPointedToBuffer(*false_buffer, index);
@ -361,8 +356,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
const HloInstruction* instruction, const ShapeIndex& index) const {
const std::vector<const LogicalBuffer*>& buffers =
GetPointsToSet(instruction).element(index);
const auto& buffers = GetPointsToSet(instruction).element(index);
return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
}
@ -398,8 +392,7 @@ const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
const HloInstruction* instruction, const ShapeIndex& index) const {
const std::vector<const LogicalBuffer*>& buffers =
GetPointsToSet(instruction).element(index);
const auto& buffers = GetPointsToSet(instruction).element(index);
if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
return FailedPrecondition(
"instruction %s does not define buffer at index {%s}",
@ -424,27 +417,26 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
const HloInstruction* instruction,
TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
GetPointsToSet(instruction)
.ForEachElement(
[this, buffers, instruction](
const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& source_buffers) {
// Add buffers which 'instruction' is the source of.
CHECK(!source_buffers.empty());
if (source_buffers.size() == 1 &&
source_buffers[0]->instruction() == instruction) {
// If this instruction is the source of this buffer the
// indices must match.
DCHECK(source_buffers[0]->index() == index);
buffers->push_back(source_buffers[0]);
} else {
// If the points-to set includes more than one buffer then
// necessarily this instruction did not produce the
// buffer.
for (const LogicalBuffer* source_buffer : source_buffers) {
DCHECK(source_buffer->instruction() != instruction);
}
}
});
.ForEachElement([this, buffers, instruction](
const ShapeIndex& index,
const PointsToSet::BufferList& source_buffers) {
// Add buffers which 'instruction' is the source of.
CHECK(!source_buffers.empty());
if (source_buffers.size() == 1 &&
source_buffers[0]->instruction() == instruction) {
// If this instruction is the source of this buffer the
// indices must match.
DCHECK(source_buffers[0]->index() == index);
buffers->push_back(source_buffers[0]);
} else {
// If the points-to set includes more than one buffer then
// necessarily this instruction did not produce the
// buffer.
for (const LogicalBuffer* source_buffer : source_buffers) {
DCHECK(source_buffer->instruction() != instruction);
}
}
});
return Status::OK();
}
@ -456,7 +448,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
const PointsToSet& src_points_to_set = GetPointsToSet(src);
dst_points_to_set.ForEachMutableElement(
[this, &dst_points_to_set, &src_points_to_set](
const ShapeIndex& index, std::vector<const LogicalBuffer*>* buffers) {
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
*buffers = src_points_to_set.element(index);
for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
dst_points_to_set.add_tuple_source(index, tuple_source);
@ -505,19 +497,18 @@ void TuplePointsToAnalysis::InstructionToString(
tensorflow::strings::StrAppend(output, prefix, " instruction ",
instruction->ToShortString(), ":\n");
const PointsToSet& points_to_set = GetPointsToSet(instruction);
points_to_set.ForEachElement(
[&prefix, &output](const ShapeIndex& index,
const std::vector<const LogicalBuffer*>& points_to) {
tensorflow::strings::StrAppend(
output, prefix, " {", tensorflow::str_util::Join(index, ","),
"}: ",
tensorflow::str_util::Join(
points_to, ", ",
[](string* out, const LogicalBuffer* source) {
out->append(source->ToString());
}),
"\n");
});
points_to_set.ForEachElement([&prefix, &output](
const ShapeIndex& index,
const PointsToSet::BufferList& points_to) {
tensorflow::strings::StrAppend(
output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ",
tensorflow::str_util::Join(
points_to, ", ",
[](string* out, const LogicalBuffer* source) {
out->append(source->ToString());
}),
"\n");
});
}
} // namespace xla

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@ -46,14 +47,12 @@ namespace xla {
// nested tuple). Each node in this tree corresponds to a single buffer in the
// instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
class PointsToSet {
public:
// Construct our ShapeTree with a pointer rather than a reference to a Shape
// because this is very hot code, and copying (and then destroying) all these
// Shapes is slow.
explicit PointsToSet(const Shape* shape)
: ShapeTree<std::vector<const LogicalBuffer*>>(shape),
tuple_sources_(shape) {}
explicit PointsToSet(const Shape* shape) : tree_(shape) {}
// Returns true if any points-to sets for any subshape element is not a
// singleton.
@ -69,7 +68,8 @@ class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
// Creates a set containing the union of all LogicalBuffers contained in the
// PointsToSet.
tensorflow::gtl::FlatSet<const LogicalBuffer*> CreateFlattenedSet() const;
using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
BufferSet CreateFlattenedSet() const;
// Returns true if the given buffer is in the points-to set at the given
// index.
@ -102,13 +102,49 @@ class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
// tuple_sources() at the index of an array shape (not a tuple) returns the
// empty set. The instructions in the set returned by tuple_sources
// necessarily are either Tuple instructions, constants, or parameters.
const std::set<HloInstruction*>& tuple_sources(const ShapeIndex& index) const;
using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
const SourceSet& tuple_sources(const ShapeIndex& index) const;
// Add a tuple source instruction for the given index.
void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
// Return the list of logical buffers for the subshape at index.
const BufferList& element(const ShapeIndex& index) const {
return tree_.element(index).buffers;
}
BufferList* mutable_element(const ShapeIndex& index) {
return &tree_.mutable_element(index)->buffers;
}
// Call fn(index, buflist) for every subshape index.
template <typename Fn>
void ForEachElement(const Fn& fn) const {
tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
fn(index, elem.buffers);
});
}
template <typename Fn>
void ForEachMutableElement(const Fn& fn) {
tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
fn(index, &elem->buffers);
});
}
template <typename Fn>
Status ForEachElementWithStatus(const Fn& fn) const {
return tree_.ForEachElementWithStatus(
[&fn](const ShapeIndex& index, const Elem& elem) {
return fn(index, elem.buffers);
});
}
private:
ShapeTree<std::set<HloInstruction*>> tuple_sources_;
struct Elem {
BufferList buffers;
SourceSet tuple_sources;
};
ShapeTree<Elem> tree_;
// PointsToSet contains references (const LogicalBuffer*) to elements within
// TuplePointsToAnalysis so disable copying.
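With the per-node Elem struct, each subshape's buffer list and tuple-source set now live in one ShapeTree instead of two parallel trees, and callers go through the BufferList/BufferSet aliases as the .cc changes above show. A hedged usage sketch mirroring those call sites (the helper function is made up; include paths are the usual XLA/TensorFlow ones):

#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Logs every logical buffer that `instruction` may point to, per subshape.
void DumpPointsTo(const TuplePointsToAnalysis& analysis,
                  const HloInstruction* instruction) {
  const PointsToSet& points_to = analysis.GetPointsToSet(instruction);
  points_to.ForEachElement(
      [](const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
        for (const LogicalBuffer* buffer : buffers) {
          VLOG(2) << "{" << tensorflow::str_util::Join(index, ",")
                  << "}: " << buffer->ToString();
        }
      });
}

}  // namespace xla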

View File

@ -35,8 +35,8 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using ::testing::UnorderedElementsAreArray;
using ::testing::UnorderedElementsAre;
using ::testing::UnorderedElementsAreArray;
class TuplePointsToAnalysisTest : public HloTestBase {
protected:
@ -62,7 +62,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// index. CHECKs if no buffer is defined at that point.
const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
const ShapeIndex& index) {
const std::vector<const LogicalBuffer*>& pointed_to =
const auto& pointed_to =
points_to_analysis_->GetPointsToSet(instruction).element(index);
CHECK_EQ(1, pointed_to.size());
CHECK_EQ(instruction, pointed_to[0]->instruction());
@ -73,7 +73,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the given
// LogicalBuffers.
void ExpectHasBuffers(
const std::vector<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferList& points_to_set,
tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
@ -82,22 +82,22 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the
// top-level buffers of the given instructions.
void ExpectHasTopLevelBuffers(
const std::vector<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferList& points_to_set,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
std::vector<const LogicalBuffer*> buffers;
PointsToSet::BufferList buffers;
for (auto instruction : instructions) {
buffers.push_back(GetBuffer(instruction, /*index=*/{}));
}
ExpectHasBuffers(points_to_set, buffers);
}
// Overload which takes a std::set instead of a std::vector.
// Overload which takes a set instead of a vector.
void ExpectHasTopLevelBuffers(
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& points_to_set,
const PointsToSet::BufferSet& points_to_set,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
ExpectHasTopLevelBuffers(std::vector<const LogicalBuffer*>(
points_to_set.begin(), points_to_set.end()),
instructions);
ExpectHasTopLevelBuffers(
PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
instructions);
}
// Checks that the buffer defined at the given instruction and index has

View File

@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
return handle;
}
StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request) {
tensorflow::mutex_lock lock(mutex_);
TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(batch_norm_inference_request.operand()));
TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
LookUpRequest(batch_norm_inference_request.scale()));
TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
LookUpRequest(batch_norm_inference_request.offset()));
TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
LookUpRequest(batch_norm_inference_request.mean()));
TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
LookUpRequest(batch_norm_inference_request.variance()));
ComputationDataHandle handle = CreateComputationDataHandle();
OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferBatchNormInferenceShape(
operand->output_shape(), scale->output_shape(),
offset->output_shape(), mean->output_shape(),
variance->output_shape(),
batch_norm_inference_request.feature_index()));
*request.mutable_output_shape() = inferred_shape;
*request.mutable_output_handle() = handle;
*request.mutable_request()->mutable_batch_norm_inference_request() =
batch_norm_inference_request;
VLOG(1) << "AddBatchNormInferenceInstruction ("
<< GetVersionedHandleInternal() << "), data handle "
<< handle.handle() << ": "
<< batch_norm_inference_request.ShortDebugString();
return handle;
}
StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request) {
tensorflow::mutex_lock lock(mutex_);
@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
ConstantVisitor(session_computation,
batch_norm_inference_request.operand(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.offset(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.variance(), visited,
is_constant);
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2119,6 +2185,18 @@ static void ForEachOperand(
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
apply(batch_norm_inference_request.operand());
apply(batch_norm_inference_request.scale());
apply(batch_norm_inference_request.offset());
apply(batch_norm_inference_request.mean());
apply(batch_norm_inference_request.variance());
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
HloInstruction* operand =
lookup_instruction(batch_norm_inference_request.operand());
HloInstruction* scale =
lookup_instruction(batch_norm_inference_request.scale());
HloInstruction* offset =
lookup_instruction(batch_norm_inference_request.offset());
HloInstruction* mean =
lookup_instruction(batch_norm_inference_request.mean());
HloInstruction* variance =
lookup_instruction(batch_norm_inference_request.variance());
hlo_instruction =
add_instruction(HloInstruction::CreateBatchNormInference(
request.output_shape(), operand, scale, offset, mean, variance,
batch_norm_inference_request.epsilon(),
batch_norm_inference_request.feature_index()));
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();

View File

@ -89,6 +89,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
const BatchNormTrainingRequest& batch_norm_training_request);
// Enqueues a batch norm inference instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request);
// Enqueues a batch norm grad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request);

View File

@ -115,25 +115,16 @@ class ShapeTree {
ShapeTree(Shape shape, const T& init_value);
ShapeTree(const Shape* shape, const T& init_value);
ShapeTree(const ShapeTree& other)
: root_(other.root_), shape_storage_(other.shape_storage_) {
// Fix up internal pointer if necessary.
if (shape_storage_) {
CHECK_EQ(other.shape_, &*other.shape_storage_);
shape_ = &*shape_storage_;
} else {
shape_ = other.shape_;
}
}
ShapeTree(const ShapeTree& other) { *this = other; }
ShapeTree& operator=(const ShapeTree& other) {
root_ = other.root_;
shape_storage_ = other.shape_storage_;
// Fix up internal pointer if necessary.
if (shape_storage_) {
CHECK_EQ(other.shape_, &*other.shape_storage_);
shape_ = &*shape_storage_;
if (other.shape_storage_) {
CHECK_EQ(other.shape_, other.shape_storage_.get());
shape_storage_.reset(new Shape(*other.shape_));
shape_ = shape_storage_.get();
} else {
shape_ = other.shape_;
}
@ -259,11 +250,11 @@ class ShapeTree {
Node root_;
// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullopt.
tensorflow::gtl::optional<Shape> shape_storage_;
// here. Otherwise if we don't own our shape, this is nullptr.
std::unique_ptr<Shape> shape_storage_;
// The XLA shape mirrored in this ShapeTree. This is either a pointer into
// shape_storage_ or the Shape pointer passed to our constructor.
// The XLA shape mirrored in this ShapeTree. This is either
// shape_storage_.get() or the Shape pointer passed to our constructor.
const Shape* shape_;
};
@ -401,10 +392,12 @@ void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
template <typename T>
ShapeTree<T>::ShapeTree(Shape shape)
: root_(), shape_storage_(std::move(shape)), shape_(&*shape_storage_) {
: root_(),
shape_storage_(MakeUnique<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&*shape_storage_);
LayoutUtil::ClearLayout(shape_storage_.get());
InitChildren(*shape_, &root_);
}
@ -416,11 +409,11 @@ ShapeTree<T>::ShapeTree(const Shape* shape) : root_(), shape_(shape) {
template <typename T>
ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
: root_(init_value),
shape_storage_(std::move(shape)),
shape_(&*shape_storage_) {
shape_storage_(MakeUnique<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&*shape_storage_);
LayoutUtil::ClearLayout(shape_storage_.get());
InitChildren(*shape_, init_value, &root_);
}
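Storing the owned Shape in a unique_ptr and funnelling the copy constructor through the rewritten copy-assignment guarantees that a copied ShapeTree's shape_ pointer targets its own deep-copied storage (or the same caller-owned Shape), never the source object's internals. A standalone sketch of that owned-versus-borrowed pattern with a hypothetical class:

#include <memory>
#include <string>

class OwnedOrBorrowed {
 public:
  explicit OwnedOrBorrowed(std::string value)
      : storage_(new std::string(std::move(value))), ptr_(storage_.get()) {}
  explicit OwnedOrBorrowed(const std::string* borrowed) : ptr_(borrowed) {}

  OwnedOrBorrowed(const OwnedOrBorrowed& other) { *this = other; }
  OwnedOrBorrowed& operator=(const OwnedOrBorrowed& other) {
    if (other.storage_) {
      storage_.reset(new std::string(*other.storage_));  // deep copy
      ptr_ = storage_.get();
    } else {
      storage_.reset();
      ptr_ = other.ptr_;  // still borrowed from the original owner
    }
    return *this;
  }

  const std::string& value() const { return *ptr_; }

 private:
  std::unique_ptr<std::string> storage_;  // set only when we own the value
  const std::string* ptr_ = nullptr;      // points into storage_ or elsewhere
};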

View File

@ -425,13 +425,42 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
HumanString(program_shape.result()));
}
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(const string& s) {
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
tensorflow::str_util::RemoveLeadingWhitespace(s);
if (s->Consume("(")) { // Tuple.
std::vector<Shape> shapes;
bool must_end = false;
while (true) {
if (s->Consume(")")) {
break;
} else if (must_end) {
return InvalidArgument("Expected end of tuple; got: \"%s\"",
s->ToString().c_str());
}
shapes.emplace_back();
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
tensorflow::str_util::RemoveLeadingWhitespace(s);
must_end = !s->Consume(",");
}
return ShapeUtil::MakeTupleShape(shapes);
}
string element_type_string;
string dimensions_string;
string layout_string;
if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?",
&element_type_string, &dimensions_string,
&layout_string)) {
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
// we convert it into the RE2-consumable type and then consume the corresponding
// amount from our StringPiece type.
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
if (RE2::Consume(&s_consumable,
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?",
&element_type_string, &dimensions_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
auto comma_list_to_int64s =
[&s](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
@ -439,39 +468,58 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
int64 element;
if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
return InvalidArgument(
"invalid value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s.c_str());
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s->ToString().c_str());
}
results.push_back(element);
}
return results;
};
// Extract the dimensions.
TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
comma_list_to_int64s(dimensions_string));
PrimitiveType primitive_type;
if (element_type_string == "f32") {
primitive_type = F32;
} else if (element_type_string == "s32") {
primitive_type = S32;
} else if (element_type_string == "u32") {
primitive_type = U32;
} else {
LOG(FATAL) << "unhandled element type string: " << element_type_string;
// Extract the primitive element type.
PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID;
for (PrimitiveType i =
static_cast<PrimitiveType>(PRIMITIVE_TYPE_INVALID + 1);
i < TUPLE; i = static_cast<PrimitiveType>(i + 1)) {
if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) ==
element_type_string) {
primitive_type = i;
break;
}
}
if (primitive_type == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}
Shape result;
if (layout_string.empty()) {
result = MakeShape(primitive_type, dimensions);
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_RET_CHECK(dimensions.size() == min2maj.size());
result = MakeShapeWithLayout(primitive_type, dimensions, min2maj);
result =
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
}
TF_DCHECK_OK(ValidateShape(result));
return result;
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
return std::move(result);
}
return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str());
return InvalidArgument("Invalid shape string to parse: \"%s\"",
s->ToString().c_str());
}
} // namespace
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
tensorflow::StringPiece s) {
return ParseShapeStringInternal(&s);
}
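
For illustration, here is a minimal standalone Python sketch of the same recursive-descent grammar that ParseShapeStringInternal consumes above: an element type, bracketed dimensions, an optional {minor-to-major} layout, and parenthesized, comma-separated tuples. The helper names (parse_shape, _ARRAY_RE) are hypothetical, the sketch assumes well-formed input, and it is not the XLA implementation.

import re

_ARRAY_RE = re.compile(r"(\w+)\[([\d,]*)\](?:\s*\{([\d,]*)\})?")

def parse_shape(s, pos=0):
  """Returns (parsed_shape, next_pos); tuple shapes become Python tuples."""
  while pos < len(s) and s[pos].isspace():
    pos += 1
  if s[pos] == "(":  # Tuple: comma-separated sub-shapes until ")".
    pos += 1
    elements = []
    while True:
      while pos < len(s) and s[pos].isspace():
        pos += 1
      if s[pos] == ")":
        return tuple(elements), pos + 1
      element, pos = parse_shape(s, pos)
      elements.append(element)
      while pos < len(s) and s[pos].isspace():
        pos += 1
      if s[pos] == ",":
        pos += 1
  m = _ARRAY_RE.match(s, pos)  # Array: element type, dims, optional layout.
  if not m:
    raise ValueError("Invalid shape string: %r" % s[pos:])
  dims = [int(d) for d in m.group(2).split(",")] if m.group(2) else []
  layout = [int(d) for d in m.group(3).split(",")] if m.group(3) else None
  return (m.group(1), dims, layout), m.end()

print(parse_shape("(f32[1],(f32[2]), f32[3]{0})")[0])
# -> (('f32', [1], None), (('f32', [2], None),), ('f32', [3], [0]))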
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,

View File

@ -125,7 +125,7 @@ class ShapeUtil {
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
static StatusOr<Shape> ParseShapeString(const string& s);
static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.

View File

@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
string shape_string = "(f32[1572864],s8[5120,1024])";
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
Shape expected =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
ShapeUtil::MakeShape(S8, {5120, 1024})});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2]), f32[3])";
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});

View File

@ -785,6 +785,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
&builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f});
auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f});
auto minimum = builder.Pow(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({});

View File

@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
ErrorSpec(0.01, 1));
}
XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());
const std::vector<int64>& bounds = GetParam().bounds;
Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
input_array.FillRandom(GetParam().random_value_var,
GetParam().random_value_mean);
const int64 feature_index = GetParam().feature_index;
const int64 num_elements_per_feature =
Product(bounds) / bounds[feature_index];
const int64 feature_bound = bounds[feature_index];
std::vector<float> offset(feature_bound, 1);
std::vector<float> scale(feature_bound, 2);
auto input_squared =
ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
std::vector<int64> reduce_dims;
for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
if (i != feature_index) {
reduce_dims.push_back(i);
}
}
auto sum =
ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });
auto sum_squared =
ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });
std::vector<float> mean(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
mean[i] = sum[i] / num_elements_per_feature;
}
std::vector<float> mean_square(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
mean_square[i] = mean[i] * mean[i];
}
std::vector<float> square_mean(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
square_mean[i] = sum_squared[i] / num_elements_per_feature;
}
std::vector<float> var(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
var[i] = square_mean[i] - mean_square[i];
}
Array4D<float> mean4D =
*ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
auto offset4D =
*ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
auto offset_literal = Literal::CreateR1<float>(offset);
auto scale_literal = Literal::CreateR1<float>(scale);
auto mean_literal = Literal::CreateR1<float>(mean);
auto var_literal = Literal::CreateR1<float>(var);
auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
auto input_activations =
builder.Parameter(0, input_literal->shape(), "input");
auto scale_activations =
builder.Parameter(1, scale_literal->shape(), "scale");
auto offset_activations =
builder.Parameter(2, offset_literal->shape(), "offset");
auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
auto variance_activations =
builder.Parameter(4, var_literal->shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
builder.BatchNormInference(input_activations, scale_activations,
offset_activations, mean_activations,
variance_activations, epsilon, feature_index);
ComputeAndCompareR4<float>(
&builder, expected,
{input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
variance_data.get()},
ErrorSpec(0.01, 1));
}
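
For reference, the quantity checked by the inference test above is the standard batch-normalization formula applied per feature index c; the test derives the per-feature statistics from the input itself, matching the reference reductions written out in the loops:

\[
  y_{i,c} = \mathrm{scale}_c \cdot \frac{x_{i,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \mathrm{offset}_c,
  \qquad
  \mu_c = \frac{1}{N} \sum_i x_{i,c},
  \qquad
  \sigma_c^2 = \frac{1}{N} \sum_i x_{i,c}^2 - \mu_c^2,
\]

where N is num_elements_per_feature and the sums run over every position that shares feature index c.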
XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@ -55,11 +56,16 @@ limitations under the License.
namespace xla {
namespace tools {
namespace {
// Invokes the given computation passing arbitrary data for every (unbound)
// parameter if use_fake_data. Otherwise, uses recorded data if available.
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
// otherwise, no infeed is performed.
StatusOr<std::unique_ptr<Literal>> ReplayComputation(
const SessionModule& module, bool use_fake_data, Client* client) {
const SessionModule& module, tensorflow::StringPiece fake_infeed_shape,
bool use_fake_data, Client* client) {
TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));
std::vector<std::unique_ptr<GlobalData>> arguments;
@ -74,6 +80,27 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
}
}
// We only instantiate the thread pool if the user has requested that a
// concurrent infeed occur via the fake_infeed_shape.
tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
if (!fake_infeed_shape.empty()) {
pool.emplace(tensorflow::Env::Default(), "infeed",
/*num_threads=*/1);
pool->Schedule([fake_infeed_shape, client]() {
StatusOr<Shape> shape_status =
ShapeUtil::ParseShapeString(fake_infeed_shape);
TF_CHECK_OK(shape_status.status());
Shape shape = std::move(shape_status).ValueOrDie();
StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
TF_CHECK_OK(data_status.status());
std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
while (true) {
TF_CHECK_OK(client->TransferToInfeed(*data));
}
});
}
std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) {
@ -82,17 +109,20 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
return client->ExecuteAndTransfer(computation, execute_arguments);
}
void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
int RealMain(tensorflow::gtl::ArraySlice<char*> args,
tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) {
Client* client = ClientLibrary::LocalClientOrDie();
tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
SessionModule module;
TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
StatusOr<std::unique_ptr<Literal>> result_status =
ReplayComputation(module, use_fake_data, client);
ReplayComputation(module, fake_infeed_shape, use_fake_data, client);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
exit_status = EXIT_FAILURE;
continue;
}
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
@ -105,17 +135,22 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
Literal(module.result()).ToString().c_str());
}
}
return exit_status;
}
} // namespace
} // namespace tools
} // namespace xla
int main(int argc, char** argv) {
// Flags
string fake_infeed_shape;
bool use_fake_data = false;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("use_fake_data", &use_fake_data,
"Replay computation using fake data"),
tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape,
"Shape of fake data to construct for (infinite) infeed"),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@ -126,6 +161,5 @@ int main(int argc, char** argv) {
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
args.pop_front(); // Pop off the binary name, argv[0]
xla::tools::RealMain(args, use_fake_data);
return 0;
return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data);
}
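
With the flag added above, an invocation that exercises the infeed path might look like the following; the binary name and file path are illustrative, only the flags come from the flag_list above:

  replay_computation --use_fake_data \
      --fake_infeed_shape='f32[16,256]' /path/to/computation.pb

The shape string is handed to ShapeUtil::ParseShapeString, so the tuple syntax added earlier in this change should also be understood.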

View File

@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
int64 feature_index = 5;
}
message BatchNormInferenceRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
ComputationDataHandle offset = 3;
ComputationDataHandle mean = 4;
ComputationDataHandle variance = 5;
float epsilon = 6;
int64 feature_index = 7;
}
message BatchNormGradRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
@ -813,7 +823,8 @@ message OpRequest {
OutfeedRequest outfeed_request = 32;
BatchNormTrainingRequest batch_norm_training_request = 35;
BatchNormGradRequest batch_norm_grad_request = 37;
// Next: 38
BatchNormInferenceRequest batch_norm_inference_request = 38;
// Next: 39
}
}

View File

@ -23,6 +23,7 @@ py_library(
"//tensorflow/contrib/data",
"//tensorflow/contrib/deprecated:deprecated_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
@ -77,9 +78,8 @@ py_library(
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],

View File

@ -241,6 +241,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.

View File

@ -21,6 +21,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
@ -28,6 +29,27 @@ py_test(
],
)
py_test(
name = "iterator_ops_cluster_test",
size = "small",
srcs = ["iterator_ops_cluster_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test(
name = "batch_dataset_op_test",
size = "small",

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
@ -229,6 +231,16 @@ class DatasetConstructorTest(test.TestCase):
dtypes.int64), dataset.output_types)
self.assertEquals(([], ([], []), []), dataset.output_shapes)
def testNestedDict(self):
components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
self.assertEquals(dtypes.int32, dataset.output_types["b"])
self.assertEquals([], dataset.output_shapes["a"]["aa"])
self.assertEquals([2], dataset.output_shapes["a"]["ab"])
self.assertEquals([3], dataset.output_shapes["b"])
def testNonSequenceNestedStructure(self):
components = np.array([1, 2, 3])
@ -255,6 +267,214 @@ class DatasetConstructorTest(test.TestCase):
self.assertEquals(dtypes.int64, get_next.dtype)
self.assertEquals([3], get_next.shape)
def _testFromGenerator(self, generator, elem_sequence, num_repeats):
iterator = (
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
.repeat(num_repeats)
.prefetch(5)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
sess.run(init_op)
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
iterator = (
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
.repeat(num_repeats)
.prefetch(5)
.make_one_shot_iterator())
get_next = iterator.get_next()
with self.test_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testFromGeneratorUsingFunction(self):
def generator():
for i in range(1, 100):
yield [i] * i
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
def testFromGeneratorUsingList(self):
generator = lambda: [[i] * i for i in range(1, 100)]
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
def testFromGeneratorUsingNdarray(self):
generator = lambda: np.arange(100, dtype=np.int64)
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
def testFromGeneratorUsingGeneratorExpression(self):
# NOTE(mrry): Generator *expressions* are not repeatable (or in
# general reusable), because they eagerly evaluate the `for`
# expression as `iter(range(1, 100))` and discard the means of
# reconstructing `range(1, 100)`. Wrapping the generator
# expression in a `lambda` makes it repeatable.
generator = lambda: ([i] * i for i in range(1, 100))
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
def testFromMultipleConcurrentGenerators(self):
num_inner_repeats = 5
num_outer_repeats = 100
def generator():
for i in range(1, 10):
yield ([i] * i, [i, i ** 2, i ** 3])
input_list = list(generator())
# The interleave transformation is essentially a flat map that
# draws from multiple input datasets concurrently (in a cyclic
# fashion). By placing `Dataset.from_generator()` inside an
# interleave, we test its behavior when multiple iterators are
# active at the same time; by additionally prefetching inside the
# interleave, we create the possibility of parallel (modulo GIL)
# invocations to several iterators created by the same dataset.
def interleave_fn(_):
return (dataset_ops.Dataset.from_generator(
generator, output_types=(dtypes.int64, dtypes.int64),
output_shapes=([None], [3]))
.repeat(num_inner_repeats).prefetch(5))
iterator = (
dataset_ops.Dataset.range(num_outer_repeats)
.interleave(interleave_fn, cycle_length=10,
block_length=len(input_list))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
val0, val1 = sess.run(get_next)
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testFromGeneratorsRunningInParallel(self):
num_parallel_iterators = 3
# Define shared state that multiple iterator instances will access to
# demonstrate their concurrent activity.
lock = threading.Lock()
condition = threading.Condition(lock)
next_ticket = [0] # GUARDED_BY(lock)
def generator():
# NOTE(mrry): We yield one element before the barrier, because
# the current implementation of `Dataset.interleave()` must
# fetch one element from each incoming dataset to start the
# prefetching.
yield 0
# Define a barrier that `num_parallel_iterators` iterators must enter
# before any can proceed. Demonstrates that multiple iterators may be
# active at the same time.
condition.acquire()
ticket = next_ticket[0]
next_ticket[0] += 1
if ticket == num_parallel_iterators - 1:
# The last iterator to join the barrier notifies the others.
condition.notify_all()
else:
# Wait until the last iterator enters the barrier.
while next_ticket[0] < num_parallel_iterators:
condition.wait()
condition.release()
yield 1
# As in `testFromMultipleConcurrentGenerators()`, we use a combination of
# `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
# iterators to be active concurrently.
def interleave_fn(_):
return dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)
iterator = (
dataset_ops.Dataset.range(num_parallel_iterators)
.interleave(
interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testFromGeneratorTypeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
yield np.array([4, 5, 6], dtype=np.int64)
yield "ERROR"
yield np.array([7, 8, 9], dtype=np.int64)
iterator = (dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of type .*int64.* was expected"):
sess.run(get_next)
self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testFromGeneratorShapeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
yield np.array([4, 5, 6], dtype=np.int64)
yield np.array([7, 8, 9, 10], dtype=np.int64)
yield np.array([11, 12, 13], dtype=np.int64)
iterator = (dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
sess.run(get_next)
self.assertAllEqual([11, 12, 13], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,109 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops that need test_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.platform import test
class IteratorClusterTest(test.TestCase):
def testRemoteIteratorWithoutRemoteCallFail(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker, _ = test_util.create_local_cluster(
1, 1, worker_config=worker_config)
with ops.device("/job:worker/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_it = dataset_ops.Iterator.from_string_handle(
iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
get_next_op = remote_it.get_next()
with session.Session(worker[0].target) as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next_op)
def testRemoteIteratorUsingRemoteCallOp(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker, _ = test_util.create_local_cluster(
1, 1, worker_config=worker_config)
with ops.device("/job:worker/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
remote_op = functional_ops.remote_call(
args=[iterator_3_handle],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)
with session.Session(worker[0].target) as sess:
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:worker/replica:0/task:0/cpu:0"
})
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:worker/replica:0/task:0/cpu:1"
})
if __name__ == "__main__":
test.main()

View File

@ -25,8 +25,10 @@ from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -416,6 +418,62 @@ class IteratorTest(test.TestCase):
feedable_int_vector.get_next(),
feed_dict={handle_placeholder: handle_float_vector}))
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
remote_op = functional_ops.remote_call(
args=[iterator_3_handle],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)
with self.test_session(config=worker_config) as sess:
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
if __name__ == "__main__":
test.main()

View File

@ -549,6 +549,41 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testReturnList(self):
iterator = (dataset_ops.Dataset.range(10)
.map(lambda x: [x, constant_op.constant(37.0)])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testMultiOutputPyFunc(self):
# The `tf.py_func()` op returns a list of tensors for its outputs.
def _map_fn(x_tensor):
def _map_py_func(x):
return x, np.array(37.0, dtype=np.float64)
return script_ops.py_func(
_map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])
iterator = (dataset_ops.Dataset.range(10)
.map(_map_fn)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@ -24,6 +24,7 @@ py_library(
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",

View File

@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import threading
import warnings
import numpy as np
@ -40,6 +42,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import gfile
@ -559,6 +562,168 @@ class Dataset(object):
"""
return SparseTensorSliceDataset(sparse_tensor)
class _GeneratorState(object):
"""Stores outstanding iterators created from a Python generator.
This class keeps track of potentially multiple iterators that may have
been created from a generator, e.g. in the case that the dataset is
repeated, or nested within a parallel computation.
"""
def __init__(self, generator):
self._generator = generator
self._lock = threading.Lock()
self._next_id = 0 # GUARDED_BY(self._lock)
self._iterators = collections.defaultdict(lambda: iter(generator()))
def get_next_id(self):
with self._lock:
ret = self._next_id
self._next_id += 1
return ret
def get_iterator(self, iterator_id):
return self._iterators[iterator_id]
def iterator_completed(self, iterator_id):
del self._iterators[iterator_id]
@staticmethod
def from_generator(generator, output_types, output_shapes=None):
"""Creates a `Dataset` whose elements are generated by `generator`.
The `generator` argument must be a callable object that returns
an object that supports the `iter()` protocol (e.g. a generator function).
The elements generated by `generator` must be compatible with the given
`output_types` and (optional) `output_shapes` arguments.
Args:
generator: A callable object that takes no arguments and returns an
object that supports the `iter()` protocol.
output_types: A nested structure of `tf.DType` objects corresponding to
each component of an element yielded by `generator`.
output_shapes: (Optional.) A nested structure of `tf.TensorShape`
objects corresponding to each component of an element yielded by
`generator`.
Returns:
A `Dataset`.
"""
if not callable(generator):
raise TypeError("`generator` must be callable.")
if output_shapes is None:
output_shapes = nest.map_structure(
lambda _: tensor_shape.TensorShape(None), output_types)
else:
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
flattened_types = nest.flatten(output_types)
flattened_shapes = nest.flatten(output_shapes)
generator_state = Dataset._GeneratorState(generator)
def get_iterator_id_map_fn(unused_dummy):
"""Creates a unique `iterator_id` for each pass over the dataset.
The "iterator_id" disambiguates between multiple concurrently
existing iterators.
Args:
unused_dummy: Ignored value.
Returns:
A `tf.int64` tensor whose value uniquely identifies an iterator in
`generator_state`.
"""
return script_ops.py_func(
generator_state.get_next_id, [], dtypes.int64, stateful=True)
def generator_map_fn(iterator_id_t):
"""Generates the next element from iterator with ID `iterator_id_t`.
We map this function across an infinite repetition of the
`iterator_id_t`, and raise `StopIteration` to terminate the iteration.
Args:
iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
the iterator in `generator_state` from which to generate an element.
Returns:
A nested structure of tensors representing an element from the iterator.
"""
def generator_py_func(iterator_id):
"""A `py_func` that will be called to invoke the iterator."""
try:
values = next(generator_state.get_iterator(iterator_id))
except StopIteration:
generator_state.iterator_completed(iterator_id)
raise StopIteration("Iteration finished.")
# Use the same _convert function from the py_func() implementation to
# convert the returned values to arrays early, so that we can inspect
# their values.
# pylint: disable=protected-access
ret_arrays = [script_ops.FuncRegistry._convert(ret)
for ret in nest.flatten_up_to(output_types, values)]
# pylint: enable=protected-access
# Additional type and shape checking to ensure that the components
# of the generated element match the `output_types` and `output_shapes`
# arguments.
for (ret_array, expected_dtype, expected_shape) in zip(
ret_arrays, flattened_types, flattened_shapes):
if ret_array.dtype != expected_dtype.as_numpy_dtype:
raise TypeError(
"`generator` yielded an element of type %s where an element "
"of type %s was expected."
% (ret_array.dtype, expected_dtype.as_numpy_dtype))
if not expected_shape.is_compatible_with(ret_array.shape):
raise ValueError(
"`generator` yielded an element of shape %s where an element "
"of shape %s was expected." % (ret_array.shape, expected_shape))
return ret_arrays
flat_values = script_ops.py_func(
generator_py_func, [iterator_id_t], flattened_types, stateful=True)
# The `py_func()` op drops the inferred shapes, so we add them back in
# here.
if output_shapes is not None:
for ret_t, shape in zip(flat_values, flattened_shapes):
ret_t.set_shape(shape)
return nest.pack_sequence_as(output_types, flat_values)
# This function associates each traversal of `generator` with a unique
# iterator ID.
def flat_map_fn(iterator_id_t):
# First, generate an infinite dataset containing the iterator ID repeated
# forever.
repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)
# The `generator_map_fn` gets the next element from the iterator with the
# relevant ID, and raises StopIteration when that iterator contains no
# more elements.
return repeated_id.map(generator_map_fn)
# A single-element dataset that, each time it is evaluated, contains a
# freshly-generated and unique (for the returned dataset) int64
# ID that will be used to identify the appropriate Python state, which
# is encapsulated in `generator_state`, and captured in
# `get_iterator_id_map_fn`.
dummy = 0
id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn)
# A dataset that contains all of the elements generated by a
# single iterator created from `generator`, identified by the
# iterator ID contained in `id_dataset`. Lifting the iteration
# into a flat_map here enables multiple repetitions and/or nested
# versions of the returned dataset to be created, because it forces
# the generation of a new ID for each version.
return id_dataset.flat_map(flat_map_fn)
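
As a usage sketch of the `Dataset.from_generator()` API defined above (illustrative only; `gen` is a made-up generator and the import paths follow this branch's contrib layout):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes

def gen():
  for i in range(1, 4):
    yield np.array([i] * i, dtype=np.int64)  # [1], [2, 2], [3, 3, 3]

dataset = dataset_ops.Dataset.from_generator(
    gen, output_types=dtypes.int64, output_shapes=[None])
get_next = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(get_next))  # [1]
  print(sess.run(get_next))  # [2 2]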
@staticmethod
def range(*args):
"""Creates a `Dataset` of a step-separated range of values.
@ -1123,6 +1288,11 @@ class Dataset(object):
}
```
NOTE: The order of elements yielded by this transformation is
deterministic, as long as `map_func` is a pure function. If
`map_func` contains any stateful operations, the order in which
that state is accessed is undefined.
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@ -1821,6 +1991,19 @@ class MapDataset(Dataset):
else:
ret = map_func(nested_args)
# If `map_func` returns a list of tensors, `nest.flatten()` and
# `ops.convert_to_tensor()` would conspire to attempt to stack
# those tensors into a single tensor, because the customized
# version of `nest.flatten()` does not recurse into lists. Since
# it is more likely that the list arose from returning the
# result of an operation (such as `tf.py_func()`) that returns a
# list of not-necessarily-stackable tensors, we treat the
# returned value as a `tuple` instead. A user wishing to pack
# the return value into a single tensor can use an explicit
# `tf.stack()` before returning.
if isinstance(ret, list):
ret = tuple(ret)
# Extract shape information from the returned values.
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
self._output_shapes = nest.pack_sequence_as(

View File

@ -40,6 +40,14 @@ import six as _six
from tensorflow.python.util.all_util import remove_undocumented
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
return sorted(_six.iterkeys(dict_))
except TypeError:
raise TypeError("nest only supports dicts with sortable keys.")
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@ -51,9 +59,13 @@ def _sequence_like(instance, args):
`args` with the type of `instance`.
"""
if isinstance(instance, dict):
# This is a dict. Iterate over the keys in sorted order to make
# this deterministic.
return {k: v for k, v in zip(sorted(instance.keys()), args)}
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
elif (isinstance(instance, tuple) and
hasattr(instance, "_fields") and
isinstance(instance._fields, _collections.Sequence) and
@ -65,16 +77,22 @@ def _sequence_like(instance, args):
return type(instance)(args)
def _elements_of(nest):
if isinstance(nest, dict):
# Iterate over dict keys in sorted order to make this deterministic.
return [v for _, v in sorted(nest.items())]
def _yield_value(iterable):
if isinstance(iterable, dict):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
else:
return nest
for value in iterable:
yield value
def _yield_flat_nest(nest):
for n in _elements_of(nest):
for n in _yield_value(nest):
if is_sequence(n):
for ni in _yield_flat_nest(n):
yield ni
@ -132,7 +150,7 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
"structure has type %s, while second structure has type %s."
% (type_nest1, type_nest2))
for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
_recursive_assert_same_structure(n1, n2, check_types)
@ -181,7 +199,7 @@ def _packed_nest_with_indices(structure, flat, index):
(assuming indexing starts from `index`).
"""
packed = []
for s in structure:
for s in _yield_value(structure):
if is_sequence(s):
new_index, child = _packed_nest_with_indices(s, flat, index)
packed.append(_sequence_like(s, child))
@ -286,8 +304,8 @@ def map_structure(func, *structure, **check_types_dict):
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
_elements_of(input_tree)):
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
_yield_value(input_tree)):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:

View File

@ -65,6 +65,73 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenDictOrder(self):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
def testPackDictOrder(self):
"""Packing orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
self.assertEqual(
collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
ordered_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
named_tuple = collections.namedtuple("A", ("b", "c"))
mess = (
"z",
named_tuple(3, 4),
{
"c": (
1,
collections.OrderedDict([
("b", 3),
("a", 2),
]),
),
"b": 5
},
17
)
flattened = nest.flatten(mess)
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
structure_of_mess = (
14,
named_tuple("a", True),
{
"c": (
0,
collections.OrderedDict([
("b", 9),
("a", 8),
]),
),
"b": 3
},
"hi everybody",
)
unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
self.assertEqual(unflattened, mess)
# Check also that the OrderedDict was created, with the correct key order.
unflattened_ordered_dict = unflattened[2]["c"][1]
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))

View File

@ -50,6 +50,7 @@ from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.test_util import *
from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import *
from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
from tensorflow.contrib.distributions.python.ops.wishart import *

View File

@ -0,0 +1,32 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
py_library(
name = "tfe",
srcs = ["tfe.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:custom_gradient",
"//tensorflow/python/eager:execution_callbacks",
"//tensorflow/python/eager:function",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
"g3doc/sitemap.md",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Eager execution prototype.
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
To use, at program startup, call `tfe.enable_eager_execution()`.
@@list_devices
@@device
@@defun
@@implicit_gradients
@@implicit_value_and_gradients
@@gradients_function
@@value_and_gradients_function
@@enable_tracing
@@flush_trace
@@run
@@enable_eager_execution
@@custom_gradient
@@add_execution_callback
@@clear_execution_callbacks
@@inf_callback
@@inf_nan_callback
@@nan_callback
@@seterr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import
#
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.eager import backprop
from tensorflow.python.eager.custom_gradient import custom_gradient
from tensorflow.python.eager import function
from tensorflow.python.eager.context import context
from tensorflow.python.eager.context import device
from tensorflow.python.eager.context import enable_eager_execution
from tensorflow.python.eager.context import run
from tensorflow.python.eager.core import enable_tracing
from tensorflow.python.eager.execution_callbacks import add_execution_callback
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
from tensorflow.python.eager.execution_callbacks import inf_callback
from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
def list_devices():
return context().devices()
defun = function.defun
implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad
gradients_function = backprop.gradients_function
value_and_gradients_function = backprop.val_and_grad_function
remove_undocumented(__name__)
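
A hedged sketch of how the names exported above are meant to be combined, based only on the docstring and imports in this file; the module is explicitly experimental, so exact behavior at this commit may differ:

import tensorflow as tf
from tensorflow.contrib.eager.python import tfe

tfe.enable_eager_execution()  # Per the docstring, call once at program startup.

def square(x):
  return tf.multiply(x, x)

grad_fn = tfe.gradients_function(square)
print(square(3.0))   # Executes immediately; no Session or graph build step.
print(grad_fn(3.0))  # A list with the gradient of square at 3.0, i.e. 6.0.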

View File

@ -119,6 +119,18 @@ py_test(
],
)
py_test(
name = "logit_fns_test",
size = "small",
srcs = ["python/learn/estimators/logit_fns_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/python:client_testlib",
"//tensorflow/python/estimator:model_fn",
],
)
py_test(
name = "estimators_test",
size = "small",

View File

@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey

View File

@ -21,7 +21,7 @@ should follow the following signature:
Args:
`features`: This is the first item returned from the `input_fn` passed to
`train`, `evaluate`, and `predict`. This should be a single
`Tensor` or `dict` of same.
`Tensor` or `dict` of same, and is the only required argument.
`mode`: Optional. Specifies if this is training, evaluation or prediction. See
`ModeKeys`.
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
@ -39,10 +39,47 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops
# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
# pylint: enable=protected-access
def call_logit_fn(logit_fn, features, mode, params, config):
"""Calls logit_fn.
A utility function that calls the provided logit_fn with the relevant subset
of provided arguments. Similar to tf.estimator._call_model_fn().
Args:
logit_fn: A logit_fn as defined above.
features: The features dict.
mode: TRAIN / EVAL / PREDICT ModeKeys.
params: The hyperparameter dict.
config: The configuration object.
Returns:
A logit Tensor, the output of logit_fn.
Raises:
ValueError: if logit_fn does not return a Tensor.
"""
logit_fn_args = util.fn_args(logit_fn)
kwargs = {}
if 'mode' in logit_fn_args:
kwargs['mode'] = mode
if 'params' in logit_fn_args:
kwargs['params'] = params
if 'config' in logit_fn_args:
kwargs['config'] = config
logit_fn_results = logit_fn(features=features, **kwargs)
if not isinstance(logit_fn_results, ops.Tensor):
raise ValueError('model_fn should return a Tensor.')
return logit_fn_results
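
The argument-subsetting behavior documented above boils down to inspecting the callable's signature and forwarding only the keyword arguments it declares. A standalone sketch of that pattern follows; the helper name is hypothetical and not part of this change:

import inspect

def call_with_accepted_args(fn, **available_kwargs):
  # Forward only the keyword arguments that `fn` actually declares, mirroring
  # the fn_args inspection performed by call_logit_fn above.
  accepted = inspect.getargspec(fn).args
  kwargs = {k: v for k, v in available_kwargs.items() if k in accepted}
  return fn(**kwargs)

def dummy_logit_fn(features, mode):
  return features["f1"] if mode == "train" else features["f2"]

print(call_with_accepted_args(
    dummy_logit_fn, features={"f1": 1.0, "f2": 2.0}, mode="eval",
    params={"learning_rate": 0.001}, config=None))  # -> 2.0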

View File

@ -0,0 +1,60 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""logit_fn tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import logit_fns
from tensorflow.python.client import session
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class LogitFnTest(test.TestCase):
def test_simple_call_logit_fn(self):
def dummy_logit_fn(features, mode):
if mode == model_fn.ModeKeys.TRAIN:
return features['f1']
else:
return features['f2']
features = {
'f1': constant_op.constant([2., 3.]),
'f2': constant_op.constant([4., 5.])
}
logit_fn_result = logit_fns.call_logit_fn(
dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
'fake_config')
with session.Session():
self.assertAllClose([[4., 5.]], logit_fn_result.eval())
def test_should_return_tensor(self):
def invalid_logit_fn(features, params):
return {
'tensor1': features['f1'] * params['input_multiplier'],
'tensor2': features['f2'] * params['input_multiplier']
}
features = {
'f1': constant_op.constant([2., 3.]),
'f2': constant_op.constant([4., 5.])
}
params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'):
logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
'fake_config')

View File

@ -73,8 +73,9 @@ HOST_INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(HOST_GENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include
@ -156,6 +157,7 @@ INCLUDES := \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
@ -237,6 +239,7 @@ ifeq ($(TARGET),ANDROID)
$(error "NDK_ROOT is not defined.")
endif
CXX := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-g++
CC := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-gcc
CXXFLAGS +=\
--sysroot $(NDK_ROOT)/platforms/android-21/arch-arm \
-Wno-narrowing \
@ -244,7 +247,6 @@ ifeq ($(TARGET),ANDROID)
-mfloat-abi=softfp \
-mfpu=neon \
-fPIE
INCLUDES = \
-I$(NDK_ROOT)/sources/android/support/include \
-I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/include \
@ -254,6 +256,7 @@ ifeq ($(TARGET),ANDROID)
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(MAKEFILE_DIR)/gen/protobuf/include \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
@ -502,11 +505,13 @@ $(wildcard tensorflow/core/user_ops/*.cu.cc) \
$(wildcard tensorflow/core/common_runtime/gpu/*) \
$(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \
$(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \
$(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \
$(wildcard tensorflow/core/grappler/clusters/single_machine.*)
# Filter out all the excluded files.
TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# Add in any extra files that don't fit the patterns easily
TF_CC_SRCS += tensorflow/core/platform/default/gpu_tracer.cc
TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c
# Also include the op and kernel definitions.
TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt)
PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt)
@ -529,7 +534,8 @@ tensorflow/core/kernels/hexagon/hexagon_remote_fused_graph_executor_build.cc
endif
# File names of the intermediate files target compilation generates.
TF_CC_OBJS := $(addprefix $(OBJDIR), $(TF_CC_SRCS:.cc=.o))
TF_CC_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_CC_SRCS))))
PBT_GEN_FILES := $(addprefix $(PBTGENDIR), $(PBT_CC_SRCS))
PBT_OBJS := $(addprefix $(OBJDIR), $(PBT_CC_SRCS:.cc=.o))
PROTO_CC_SRCS := $(addprefix $(PROTOGENDIR), $(PROTO_SRCS:.proto=.pb.cc))
@ -567,6 +573,14 @@ $(OBJDIR)%.o: %.cc | $(PBT_GEN_FILES)
$(CXX) $(CXXFLAGS) $(DEPFLAGS) $(INCLUDES) -c $< -o $@
@mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d
# Matches on plain C files.
$(OBJDIR)%.o: %.c
@mkdir -p $(dir $@)
@mkdir -p $(dir $(DEPDIR)$*)
$(CXX) $(patsubst --std=c++11,--std=c99, $(CXXFLAGS)) -x c $(DEPFLAGS) \
$(INCLUDES) -c $< -o $@
@mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d
# Compiles C++ source files that have been generated by protoc.
$(OBJDIR)%.pb.o: $(PROTOGENDIR)%.pb.cc
@mkdir -p $(dir $@)

View File

@ -25,6 +25,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
NSYNC_URL="$(grep -o 'http.*github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
PROTOBUF_URL="$(grep -o 'http.*github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
RE2_URL="$(grep -o 'http.*github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
@ -60,6 +61,7 @@ download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest"
download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync"
download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf"
download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2"
download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d"
replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"

Some files were not shown because too many files have changed in this diff.