[tf.data vectorization] Add vectorizer for Add
op
PiperOrigin-RevId: 216424512
This commit is contained in:
parent
35caff9574
commit
950cf87104
tensorflow
core
graph
grappler/optimizers/data
python/data/experimental/kernel_tests/optimization
@ -34,7 +34,7 @@ namespace tensorflow {
|
||||
|
||||
const int Graph::kControlSlot = -1;
|
||||
|
||||
class NodeProperties {
|
||||
struct NodeProperties {
|
||||
public:
|
||||
NodeProperties(const OpDef* op_def, const NodeDef& node_def,
|
||||
const DataTypeSlice inputs, const DataTypeSlice outputs)
|
||||
|
@ -9,7 +9,11 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
|
||||
|
||||
VECTORIZER_DEPS = [
|
||||
":vectorizer_registry",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/core/grappler/optimizers/data:graph_utils",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/cc:scope_internal",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
] + tf_protos_all()
|
||||
|
||||
cc_library(
|
||||
@ -42,6 +46,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "vectorizer_registry_test",
|
||||
srcs = ["vectorizer_registry_test.cc"],
|
||||
deps = [
|
||||
":vectorizer_registry",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "add_vectorizer",
|
||||
srcs = ["add_vectorizer.cc"],
|
||||
deps = VECTORIZER_DEPS,
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cast_vectorizer",
|
||||
srcs = ["cast_vectorizer.cc"],
|
||||
@ -61,20 +83,10 @@ cc_library(
|
||||
hdrs = ["vectorizer_registry.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":add_vectorizer",
|
||||
":cast_vectorizer",
|
||||
":unpack_vectorizer",
|
||||
":vectorizer",
|
||||
":vectorizer_registry",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "vectorizer_registry_test",
|
||||
srcs = ["vectorizer_registry_test.cc"],
|
||||
deps = [
|
||||
":vectorizer_registry",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
@ -0,0 +1,150 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
const char* const kExpandDimsPrefix = "vectorized/expanddims/";
|
||||
|
||||
// Reshapes stacked inputs for broadcast. Stacked inputs have an extra leading
|
||||
// dimension, which may cause automatic broadcasting rules to expand the
|
||||
// input dimensions wrongly when the unstacked shapes have different ranks.
|
||||
// To avoid that, we reshape stacked inputs to the maximum rank they need
|
||||
// to be broadcasted to.
|
||||
//
|
||||
// For example, suppose we have inputs A and B, where A is a stacked tensor with
|
||||
// shape [n, 5] (where n is the stack size) and B is an unstacked tensor with
|
||||
// shape [12, 7, 5]. If we added them directly, tensorflow broadcasting rules
|
||||
// would expand the dimensions of A to [1, n, 5], then (incorrectly) check that
|
||||
// the dimensions n and 7 are compatible, and if so, create an output of shape
|
||||
// [12, 7, 5]. However, correct addition of these inputs would create an output
|
||||
// with shape [n, 12, 7, 5]: we need to manually expand the dimensions of A
|
||||
// *after* the leading dimension, i.e. expand A to the shape [n, 1, 1, 5] before
|
||||
// broadcasting.
|
||||
Status ExpandDimsForBroadcast(std::vector<WrappedTensor>* inputs, Graph* g) {
|
||||
Status status;
|
||||
Scope parent = NewInternalScope(g, &status, nullptr);
|
||||
Scope s = parent.NewSubScope(kExpandDimsPrefix);
|
||||
|
||||
// TODO(rachelim): We can potentially get rid of all these ops if shapes are
|
||||
// known statically
|
||||
|
||||
Output const_0 = ops::Const(s, 0);
|
||||
Output const_1 = ops::Const(s, 1);
|
||||
|
||||
std::vector<Output> ranks;
|
||||
ranks.reserve(inputs->size());
|
||||
|
||||
// Get the stacked rank of each input
|
||||
for (const auto& input : *inputs) {
|
||||
Output rank = ops::Rank(s, Output(input.node, input.output_index));
|
||||
|
||||
if (!input.stacked) {
|
||||
// If the input is unstacked, add 1
|
||||
rank = ops::Add(s, rank, const_1);
|
||||
}
|
||||
|
||||
ranks.push_back(rank);
|
||||
}
|
||||
|
||||
// Pack the ranks into one tensor to get the max
|
||||
Output packed_ranks = ops::Stack(s, ranks);
|
||||
|
||||
Output max_rank =
|
||||
ops::Max(s, packed_ranks, const_0, ops::Max::Attrs().KeepDims(true));
|
||||
|
||||
std::vector<WrappedTensor> expanded_inputs;
|
||||
expanded_inputs.reserve(inputs->size());
|
||||
|
||||
// For all inputs that are stacked, expand dimensions after dim 0.
|
||||
for (size_t i = 0; i < inputs->size(); ++i) {
|
||||
if (!inputs->at(i).stacked) {
|
||||
expanded_inputs.push_back(inputs->at(i));
|
||||
continue;
|
||||
}
|
||||
|
||||
Output input(inputs->at(i).node, inputs->at(i).output_index);
|
||||
|
||||
// Number of dimensions to expand
|
||||
Output rank_diff = ops::Sub(s, max_rank, ranks[i]);
|
||||
|
||||
// [1] * rank_diff
|
||||
Output ones = ops::Tile(s, ops::Const(s, {1}), rank_diff);
|
||||
|
||||
Output const_vec_1 = ops::Const(s, {1});
|
||||
|
||||
Output shape = ops::Shape(s, input);
|
||||
|
||||
// shape[:1]
|
||||
Output concat_pre =
|
||||
ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1,
|
||||
ops::StridedSlice::Attrs().BeginMask(1));
|
||||
|
||||
// shape[1:]
|
||||
Output concat_post =
|
||||
ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1,
|
||||
ops::StridedSlice::Attrs().EndMask(1));
|
||||
|
||||
// tf.concat([shape[:1], ones, shape[1:]], 0)
|
||||
Output new_shape = ops::Concat(s, {concat_pre, ones, concat_post}, const_0);
|
||||
|
||||
Output result = ops::Reshape(s, input, new_shape);
|
||||
|
||||
expanded_inputs.push_back({result.node(), 0, true});
|
||||
}
|
||||
|
||||
inputs->swap(expanded_inputs);
|
||||
return status;
|
||||
}
|
||||
|
||||
class AddVectorizer : public Vectorizer {
|
||||
public:
|
||||
Status Vectorize(const Node& node, Graph* outer_scope,
|
||||
std::vector<WrappedTensor>&& inputs,
|
||||
std::vector<WrappedTensor>* outputs) override {
|
||||
if (node.num_inputs() != 2) {
|
||||
return errors::Internal("Add op should only have two inputs.");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope));
|
||||
|
||||
// Add new Add node with the same op and attrs as the original node
|
||||
Node* new_add_node;
|
||||
TF_RETURN_IF_ERROR(NodeBuilder("Add", "Add")
|
||||
.Input(inputs[0].node, inputs[0].output_index)
|
||||
.Input(inputs[1].node, inputs[1].output_index)
|
||||
.Finalize(outer_scope, &new_add_node));
|
||||
|
||||
// Add output mappings
|
||||
outputs->push_back({new_add_node, 0, true});
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_VECTORIZER("Add", AddVectorizer);
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -64,9 +64,18 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
|
||||
}
|
||||
}
|
||||
|
||||
// Update node attrs to keep its properties consistent with the function
|
||||
void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
|
||||
map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);
|
||||
|
||||
// TODO(rachelim): Propagate precise shapes if they're known, which may enable
|
||||
// subsequent optimizations.
|
||||
map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
|
||||
map_defun_fn->ret_types.size()));
|
||||
}
|
||||
|
||||
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
|
||||
const TensorDesc& output) {
|
||||
// Note that we don't update MapDefun attrs as we go, only when we are done
|
||||
DataType type = output.first->output_type(output.second);
|
||||
int index = map_defun_fn->ret_nodes.size();
|
||||
|
||||
@ -83,13 +92,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
|
||||
map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
|
||||
map_defun_fn->ret_nodes.push_back(ret_node);
|
||||
map_defun_fn->ret_types.push_back(type);
|
||||
UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
|
||||
FunctionBody* map_defun_fn, Node* map_defun_node) {
|
||||
// Note that we don't update MapDefun attrs as we go, only when we are done
|
||||
DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
|
||||
<< "Trying to remove output that doesn't exist. Output number: "
|
||||
<< output_position;
|
||||
@ -102,6 +111,7 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
|
||||
output_position);
|
||||
map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
|
||||
output_position);
|
||||
UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
|
||||
|
||||
// Renumber the nodes and edges that come after
|
||||
for (int i = 0; i < num_later_outputs; ++i) {
|
||||
@ -342,13 +352,6 @@ void Vectorization::VectorizeHelper() {
|
||||
// need the MapDefun node and can delete it.
|
||||
if (map_defun_fn_->ret_nodes.empty()) {
|
||||
outer_scope_->RemoveNode(map_defun_node_);
|
||||
} else {
|
||||
// Update MapDefun node attrs accordingly
|
||||
DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
|
||||
map_defun_node_->AddAttr(
|
||||
"output_shapes",
|
||||
std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
|
||||
map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
|
||||
FunctionDef* vectorized;
|
||||
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
|
||||
LOG(ERROR) << s;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
|
||||
@ -237,7 +237,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
|
||||
auto map_defun_node = vectorized->node_def(
|
||||
function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
@ -311,7 +311,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
const NodeDef& cast_node = vectorized->node_def(
|
||||
@ -389,7 +389,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
const NodeDef& cast_node = vectorized->node_def(
|
||||
@ -475,7 +475,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
const NodeDef& unpack_node = vectorized->node_def(
|
||||
@ -574,7 +574,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
const NodeDef& cast_node = vectorized->node_def(
|
||||
@ -654,7 +654,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
// They should be unchanged
|
||||
// We check this somewhat manually as the names of nodes may have changed
|
||||
EXPECT_EQ(vectorized->node_def_size(), 1);
|
||||
@ -738,7 +738,7 @@ TEST(VectorizeMapDefunTest, VectorizeConst) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
|
||||
@ -817,7 +817,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
auto const_node = vectorized->node_def(
|
||||
@ -902,7 +902,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
|
||||
*lib.add_function() = inner;
|
||||
|
||||
FunctionDef* vectorized;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
|
||||
auto find_const = [vectorized](int val) -> const NodeDef* {
|
||||
for (const auto& n : vectorized->node_def()) {
|
||||
@ -924,6 +924,89 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
|
||||
EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
|
||||
}
|
||||
|
||||
// Before:
|
||||
//
|
||||
// +------+
|
||||
// +-----------------+ Arg0 +----------------------+
|
||||
// | +---+--+ |
|
||||
// | | |
|
||||
// | +---v--+ |
|
||||
// | +-------------+ Arg0 +------------------+ |
|
||||
// | | +---+--+ | |
|
||||
// | | | | |
|
||||
// | | | +-----+ | |
|
||||
// | | | |Const| | |
|
||||
// | | | +-+---+ | |
|
||||
// | | | | | |
|
||||
// | | | +--------+ | |
|
||||
// | | | | | |
|
||||
// | | +-v---v-+ | |
|
||||
// | | | Add | | |
|
||||
// | | +-+-----+ | |
|
||||
// | | | | |
|
||||
// | | | | |
|
||||
// | | MapDefun +-v----+ | |
|
||||
// | +---------------| Ret |----------------+ |
|
||||
// | +--v---+ |
|
||||
// | | |
|
||||
// | | |
|
||||
// | +--v---- |
|
||||
// +-------------------| Ret |--------------------+
|
||||
// +------+
|
||||
//
|
||||
//
|
||||
// After:
|
||||
//
|
||||
// +------+
|
||||
// +------------+ Arg0 +----------------------+
|
||||
// | +---+--+ |
|
||||
// | | |
|
||||
// | | +-----+ |
|
||||
// | | |Const| |
|
||||
// | +-v---------+ +--+--+ |
|
||||
// | |ExpandDims*| | |
|
||||
// | +-----+-----+ | |
|
||||
// | | | |
|
||||
// | +-----+ +-----+ |
|
||||
// | | | |
|
||||
// | +-v-v-+ |
|
||||
// | | Add | |
|
||||
// | +--+--+ |
|
||||
// | | |
|
||||
// | +---v--+ |
|
||||
// +-----------------------+ Ret +-----------+
|
||||
// +------+
|
||||
//
|
||||
TEST(VectorizeMapDefunTest, VectorizeDefunAdd) {
|
||||
// Note that this checks that the "Add" vectorizer is successful, but does not
|
||||
// check that the transformed function is correct (i.e. produces the same
|
||||
// output as the unvectorized map defun). For the latter, the tests are in
|
||||
// tensorflow/python/data/experimental/kernel_tests/optimization/
|
||||
// map_vectorization_test.py
|
||||
FunctionDef inner = FunctionDefHelper::Create(
|
||||
"inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
|
||||
{/* nodes */ FunctionDefHelper::Const("Const", 2),
|
||||
{{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}},
|
||||
{{"ret0", "Add:z:0"}});
|
||||
|
||||
FunctionDef outer = FunctionDefHelper::Create(
|
||||
"outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
|
||||
{/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
|
||||
|
||||
NodeDef* map_defun =
|
||||
AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
|
||||
inner.signature().name(), &outer);
|
||||
CHECK_NOTNULL(map_defun);
|
||||
|
||||
FunctionDefLibrary lib;
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
}
|
||||
|
||||
// TODO(rachelim): More test cases when we get around to implementing them:
|
||||
// [] A badly defined converter, e.g. doesn't produce nodes that have the
|
||||
// same number of outputs/inputs as the nodes to be converted
|
||||
|
@ -80,6 +80,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
("Basic", lambda x: (x, x + 1), None),
|
||||
("Const", lambda x: 2, 12),
|
||||
("Parallel", lambda x: (x, x + 1), 12),
|
||||
("Broadcast", lambda x: x + np.random.rand(5, 4, 3, 2), None),
|
||||
("Gather", lambda x: array_ops.gather(x, 0), 12),
|
||||
)
|
||||
def testOptimization(self, map_fn, num_parallel_calls):
|
||||
|
Loading…
Reference in New Issue
Block a user