[tf.data vectorization] Add vectorizer for Add op

PiperOrigin-RevId: 216424512
This commit is contained in:
Rachel Lim 2018-10-09 14:36:33 -07:00 committed by TensorFlower Gardener
parent 35caff9574
commit 950cf87104
6 changed files with 280 additions and 31 deletions
tensorflow
core
python/data/experimental/kernel_tests/optimization

View File

@ -34,7 +34,7 @@ namespace tensorflow {
const int Graph::kControlSlot = -1;
class NodeProperties {
struct NodeProperties {
public:
NodeProperties(const OpDef* op_def, const NodeDef& node_def,
const DataTypeSlice inputs, const DataTypeSlice outputs)

View File

@ -9,7 +9,11 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
"//tensorflow/cc:ops",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
"//tensorflow/core:core_cpu",
"//tensorflow/cc:scope_internal",
"//tensorflow/cc:cc_ops",
] + tf_protos_all()
cc_library(
@ -42,6 +46,24 @@ cc_library(
],
)
tf_cc_test(
name = "vectorizer_registry_test",
srcs = ["vectorizer_registry_test.cc"],
deps = [
":vectorizer_registry",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + tf_protos_all(),
)
# Library holding the vectorizer for the "Add" op (add_vectorizer.cc).
# NOTE(review): alwayslink = 1 forces the linker to keep this library's object
# files even when no symbol in them is referenced directly. Presumably this is
# needed because the vectorizer is only reachable through its static
# registration (REGISTER_VECTORIZER) -- confirm against vectorizer_registry.h.
cc_library(
name = "add_vectorizer",
srcs = ["add_vectorizer.cc"],
deps = VECTORIZER_DEPS,
alwayslink = 1,
)
cc_library(
name = "cast_vectorizer",
srcs = ["cast_vectorizer.cc"],
@ -61,20 +83,10 @@ cc_library(
hdrs = ["vectorizer_registry.h"],
visibility = ["//visibility:public"],
deps = [
":add_vectorizer",
":cast_vectorizer",
":unpack_vectorizer",
":vectorizer",
":vectorizer_registry",
],
)
tf_cc_test(
name = "vectorizer_registry_test",
srcs = ["vectorizer_registry_test.cc"],
deps = [
":vectorizer_registry",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + tf_protos_all(),
)

View File

@ -0,0 +1,150 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
namespace grappler {
namespace {
const char* const kExpandDimsPrefix = "vectorized/expanddims/";
// Prepares the inputs of a broadcasting op for vectorization by reshaping
// every stacked input so that broadcasting lines up the right dimensions.
//
// A stacked input carries one extra leading dimension (the stack size). If the
// unstacked shapes of the inputs have different ranks, the default
// broadcasting rules would pad the stacked input on the *left* and end up
// matching the stack dimension against an unrelated dimension of another
// input. To prevent this, each stacked input is reshaped to the largest
// stacked rank among all inputs by inserting size-1 dimensions directly
// *after* dimension 0.
//
// Example: A is stacked with shape [n, 5] (n is the stack size) and B is
// unstacked with shape [12, 7, 5]. Adding them as-is would broadcast A to
// [1, n, 5] and (incorrectly) compare n with 7. The correct result has shape
// [n, 12, 7, 5], so A is reshaped to [n, 1, 1, 5] first.
//
// Args:
//   inputs: tensors to adjust; replaced in place. Unstacked entries are
//     carried over unchanged, stacked entries are replaced by the output of
//     a new Reshape node.
//   g: graph in which the helper nodes are created, under the
//     "vectorized/expanddims/" sub-scope.
//
// Returns the status accumulated by the internal scope while building nodes.
Status ExpandDimsForBroadcast(std::vector<WrappedTensor>* inputs, Graph* g) {
  Status status;
  Scope root = NewInternalScope(g, &status, nullptr);
  Scope scope = root.NewSubScope(kExpandDimsPrefix);

  // TODO(rachelim): We can potentially get rid of all these ops if shapes are
  // known statically

  Output zero = ops::Const(scope, 0);
  Output one = ops::Const(scope, 1);

  // Rank each input has (or would have) once stacked: unstacked inputs count
  // one extra dimension.
  std::vector<Output> stacked_ranks;
  stacked_ranks.reserve(inputs->size());
  for (const auto& tensor : *inputs) {
    Output rank = ops::Rank(scope, Output(tensor.node, tensor.output_index));
    if (tensor.stacked) {
      stacked_ranks.push_back(rank);
    } else {
      Output incremented = ops::Add(scope, rank, one);
      stacked_ranks.push_back(incremented);
    }
  }

  // Largest stacked rank across all inputs, kept as a 1-element vector so it
  // can be used directly as the `multiples` argument of Tile below.
  Output all_ranks = ops::Stack(scope, stacked_ranks);
  Output target_rank =
      ops::Max(scope, all_ranks, zero, ops::Max::Attrs().KeepDims(true));

  std::vector<WrappedTensor> reshaped;
  reshaped.reserve(inputs->size());

  for (size_t idx = 0; idx < inputs->size(); ++idx) {
    const WrappedTensor& current = inputs->at(idx);

    // Only stacked inputs need padding after dim 0.
    if (!current.stacked) {
      reshaped.push_back(current);
      continue;
    }

    Output tensor(current.node, current.output_index);

    // How many size-1 dimensions have to be inserted.
    Output missing_dims = ops::Sub(scope, target_rank, stacked_ranks[idx]);

    // [1] * missing_dims
    Output padding = ops::Tile(scope, ops::Const(scope, {1}), missing_dims);

    Output vec_one = ops::Const(scope, {1});
    Output original_shape = ops::Shape(scope, tensor);

    // original_shape[:1] -- begin is masked out, end is 1.
    Output leading =
        ops::StridedSlice(scope, original_shape, vec_one, vec_one, vec_one,
                          ops::StridedSlice::Attrs().BeginMask(1));

    // original_shape[1:] -- begin is 1, end is masked out.
    Output trailing =
        ops::StridedSlice(scope, original_shape, vec_one, vec_one, vec_one,
                          ops::StridedSlice::Attrs().EndMask(1));

    // tf.concat([shape[:1], ones, shape[1:]], 0)
    Output padded_shape =
        ops::Concat(scope, {leading, padding, trailing}, zero);

    Output result = ops::Reshape(scope, tensor, padded_shape);
    reshaped.push_back({result.node(), 0, true});
  }

  inputs->swap(reshaped);
  return status;
}
// Vectorizer for the "Add" op.
//
// Produces a new "Add" node in the outer scope that consumes the (possibly
// stacked) inputs directly. Stacked inputs are first reshaped by
// ExpandDimsForBroadcast so that broadcasting does not mix up the stack
// dimension with the dimensions of the other operand.
class AddVectorizer : public Vectorizer {
 public:
  // Args:
  //   node: the original "Add" node being vectorized.
  //   outer_scope: graph that receives the new nodes.
  //   inputs: the two operands for the new node, in order.
  //   outputs: receives the single output of the new "Add" node, which is
  //     always reported as stacked.
  //
  // Returns an Internal error if the node or `inputs` does not have exactly
  // two entries, or any error raised while building the new nodes.
  Status Vectorize(const Node& node, Graph* outer_scope,
                   std::vector<WrappedTensor>&& inputs,
                   std::vector<WrappedTensor>* outputs) override {
    if (node.num_inputs() != 2) {
      return errors::Internal("Add op should only have two inputs.");
    }
    // `inputs` is indexed at 0 and 1 below; make sure both entries exist
    // instead of relying on it matching node.num_inputs().
    if (inputs.size() != 2) {
      return errors::Internal(
          "Add vectorizer should be given exactly two inputs.");
    }

    TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope));

    // Build a fresh "Add" node on the reshaped inputs. No attrs are copied
    // from the original node here; only the two data inputs are wired up.
    Node* new_add_node;
    TF_RETURN_IF_ERROR(NodeBuilder("Add", "Add")
                           .Input(inputs[0].node, inputs[0].output_index)
                           .Input(inputs[1].node, inputs[1].output_index)
                           .Finalize(outer_scope, &new_add_node));

    // Add output mappings
    outputs->push_back({new_add_node, 0, true});
    return Status::OK();
  }
};
REGISTER_VECTORIZER("Add", AddVectorizer);
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -64,9 +64,18 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
}
}
// Re-derives the MapDefun node's output attrs from the function body so the
// node stays consistent with the function's current return values.
//
// "output_types" is set to the function's return types. "output_shapes" is
// set to one default-constructed PartialTensorShape per return value.
void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
  const auto& ret_types = map_defun_fn->ret_types;
  map_defun_node->AddAttr("output_types", ret_types);

  // TODO(rachelim): Propagate precise shapes if they're known, which may enable
  // subsequent optimizations.
  const std::vector<PartialTensorShape> placeholder_shapes(ret_types.size());
  map_defun_node->AddAttr("output_shapes", placeholder_shapes);
}
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
const TensorDesc& output) {
// Note that we don't update MapDefun attrs as we go, only when we are done
DataType type = output.first->output_type(output.second);
int index = map_defun_fn->ret_nodes.size();
@ -83,13 +92,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
map_defun_fn->ret_nodes.push_back(ret_node);
map_defun_fn->ret_types.push_back(type);
UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
return s;
}
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
FunctionBody* map_defun_fn, Node* map_defun_node) {
// Note that we don't update MapDefun attrs as we go, only when we are done
DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
<< "Trying to remove output that doesn't exist. Output number: "
<< output_position;
@ -102,6 +111,7 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
output_position);
map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
output_position);
UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
// Renumber the nodes and edges that come after
for (int i = 0; i < num_later_outputs; ++i) {
@ -342,13 +352,6 @@ void Vectorization::VectorizeHelper() {
// need the MapDefun node and can delete it.
if (map_defun_fn_->ret_nodes.empty()) {
outer_scope_->RemoveNode(map_defun_node_);
} else {
// Update MapDefun node attrs accordingly
DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
map_defun_node_->AddAttr(
"output_shapes",
std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}

View File

@ -145,7 +145,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
FunctionDef* vectorized;
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
LOG(ERROR) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
@ -237,7 +237,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
auto map_defun_node = vectorized->node_def(
function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
@ -311,7 +311,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@ -389,7 +389,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@ -475,7 +475,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& unpack_node = vectorized->node_def(
@ -574,7 +574,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@ -654,7 +654,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
// They should be unchanged
// We check this somewhat manually as the names of nodes may have changed
EXPECT_EQ(vectorized->node_def_size(), 1);
@ -738,7 +738,7 @@ TEST(VectorizeMapDefunTest, VectorizeConst) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
@ -817,7 +817,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
auto const_node = vectorized->node_def(
@ -902,7 +902,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
*lib.add_function() = inner;
FunctionDef* vectorized;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
auto find_const = [vectorized](int val) -> const NodeDef* {
for (const auto& n : vectorized->node_def()) {
@ -924,6 +924,89 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
}
// Before:
//
// +------+
// +-----------------+ Arg0 +----------------------+
// | +---+--+ |
// | | |
// | +---v--+ |
// | +-------------+ Arg0 +------------------+ |
// | | +---+--+ | |
// | | | | |
// | | | +-----+ | |
// | | | |Const| | |
// | | | +-+---+ | |
// | | | | | |
// | | | +--------+ | |
// | | | | | |
// | | +-v---v-+ | |
// | | | Add | | |
// | | +-+-----+ | |
// | | | | |
// | | | | |
// | | MapDefun +-v----+ | |
// | +---------------| Ret |----------------+ |
// | +--v---+ |
// | | |
// | | |
// | +--v---+ |
// +-------------------| Ret |--------------------+
// +------+
//
//
// After:
//
// +------+
// +------------+ Arg0 +----------------------+
// | +---+--+ |
// | | |
// | | +-----+ |
// | | |Const| |
// | +-v---------+ +--+--+ |
// | |ExpandDims*| | |
// | +-----+-----+ | |
// | | | |
// | +-----+ +-----+ |
// | | | |
// | +-v-v-+ |
// | | Add | |
// | +--+--+ |
// | | |
// | +---v--+ |
// +-----------------------+ Ret +-----------+
// +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunAdd) {
  // This only verifies that the "Add" vectorizer runs successfully and that
  // the MapDefun node is gone afterwards. It does not verify that the
  // vectorized function computes the same result as the original map defun;
  // that is covered by
  // tensorflow/python/data/experimental/kernel_tests/optimization/
  // map_vectorization_test.py

  // Function being mapped: ret0 = arg0 + 2.
  FunctionDef inner_fn = FunctionDefHelper::Create(
      "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
      {/* nodes */ FunctionDefHelper::Const("Const", 2),
       {{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}},
      {{"ret0", "Add:z:0"}});

  // Enclosing function whose only node is the MapDefun added below.
  FunctionDef outer_fn = FunctionDefHelper::Create(
      "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
      {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});

  NodeDef* map_defun_node =
      AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
                      inner_fn.signature().name(), &outer_fn);
  CHECK_NOTNULL(map_defun_node);

  FunctionDefLibrary library;
  *library.add_function() = outer_fn;
  *library.add_function() = inner_fn;

  FunctionDef* result;
  TF_EXPECT_OK(
      VectorizeMapDefun(outer_fn, *map_defun_node, &library, &result));

  // A fully vectorized function no longer contains any MapDefun node.
  const bool has_map_defun =
      function_utils::ContainsFunctionNodeWithOp("MapDefun", *result);
  EXPECT_TRUE(!has_map_defun);
}
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted

View File

@ -80,6 +80,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
("Basic", lambda x: (x, x + 1), None),
("Const", lambda x: 2, 12),
("Parallel", lambda x: (x, x + 1), 12),
("Broadcast", lambda x: x + np.random.rand(5, 4, 3, 2), None),
("Gather", lambda x: array_ops.gather(x, 0), 12),
)
def testOptimization(self, map_fn, num_parallel_calls):