[Grappler] Add arithmetic optimizer stage for optimizing tf.nn.embedding_lookup_sparse()
.
This optimization eliminates unnecessary `tf.unique()` and `tf.gather()` operations from `tf.nn.embedding_lookup_sparse()` when the embeddings are unpartitioned (e.g. at inference time) and weights are not used. Instead, the `tf.sparse.segment_<combiner>()` operation is applied directly to the embedding matrix without uniquifying the IDs. PiperOrigin-RevId: 308338718 Change-Id: I6e689610a8f4f3dd0a3e8af77cce609ac6d4f9f9
This commit is contained in:
parent
64b83e8724
commit
3b336d3173
@ -76,6 +76,15 @@ bool IsAnyMin(const NodeDef& node) {
|
|||||||
return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
|
return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsAnySparseSegmentReduction(const NodeDef& node) {
|
||||||
|
const auto& op = node.op();
|
||||||
|
return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
|
||||||
|
op == "SparseSegmentMean" ||
|
||||||
|
op == "SparseSegmentMeanWithNumSegments" ||
|
||||||
|
op == "SparseSegmentSqrtN" ||
|
||||||
|
op == "SparseSegmentSqrtNWithNumSegments";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsApproximateEqual(const NodeDef& node) {
|
bool IsApproximateEqual(const NodeDef& node) {
|
||||||
return node.op() == "ApproximateEqual";
|
return node.op() == "ApproximateEqual";
|
||||||
}
|
}
|
||||||
@ -268,6 +277,11 @@ bool IsFusedBatchNormGrad(const NodeDef& node) {
|
|||||||
op == "FusedBatchNormGradV3";
|
op == "FusedBatchNormGradV3";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsGather(const NodeDef& node) {
|
||||||
|
const auto& op = node.op();
|
||||||
|
return op == "Gather" || op == "GatherV2";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
|
bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
|
||||||
|
|
||||||
bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
|
bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
|
||||||
@ -589,6 +603,11 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
|
|||||||
|
|
||||||
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
|
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
|
||||||
|
|
||||||
|
bool IsUnique(const NodeDef& node) {
|
||||||
|
const auto& op = node.op();
|
||||||
|
return op == "Unique" || op == "UniqueV2";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
|
bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
|
||||||
|
|
||||||
bool IsVariable(const NodeDef& node) {
|
bool IsVariable(const NodeDef& node) {
|
||||||
|
@ -34,6 +34,7 @@ bool IsAnyMax(const NodeDef& node);
|
|||||||
bool IsAnyMaxPool(const NodeDef& node);
|
bool IsAnyMaxPool(const NodeDef& node);
|
||||||
bool IsAnyMin(const NodeDef& node);
|
bool IsAnyMin(const NodeDef& node);
|
||||||
bool IsAnyMul(const NodeDef& node);
|
bool IsAnyMul(const NodeDef& node);
|
||||||
|
bool IsAnySparseSegmentReduction(const NodeDef& node);
|
||||||
bool IsApproximateEqual(const NodeDef& node);
|
bool IsApproximateEqual(const NodeDef& node);
|
||||||
bool IsArg(const NodeDef& node);
|
bool IsArg(const NodeDef& node);
|
||||||
bool IsArgMax(const NodeDef& node);
|
bool IsArgMax(const NodeDef& node);
|
||||||
@ -81,6 +82,7 @@ bool IsFloorMod(const NodeDef& node);
|
|||||||
bool IsFusedBatchNorm(const NodeDef& node);
|
bool IsFusedBatchNorm(const NodeDef& node);
|
||||||
bool IsFusedBatchNormEx(const NodeDef& node);
|
bool IsFusedBatchNormEx(const NodeDef& node);
|
||||||
bool IsFusedBatchNormGrad(const NodeDef& node);
|
bool IsFusedBatchNormGrad(const NodeDef& node);
|
||||||
|
bool IsGather(const NodeDef& node);
|
||||||
bool IsGreater(const NodeDef& node);
|
bool IsGreater(const NodeDef& node);
|
||||||
bool IsGreaterEqual(const NodeDef& node);
|
bool IsGreaterEqual(const NodeDef& node);
|
||||||
bool IsHistogramSummary(const NodeDef& node);
|
bool IsHistogramSummary(const NodeDef& node);
|
||||||
@ -187,6 +189,7 @@ bool IsTile(const NodeDef& node);
|
|||||||
bool IsTranspose(const NodeDef& node);
|
bool IsTranspose(const NodeDef& node);
|
||||||
bool IsTruncateDiv(const NodeDef& node);
|
bool IsTruncateDiv(const NodeDef& node);
|
||||||
bool IsTruncateMod(const NodeDef& node);
|
bool IsTruncateMod(const NodeDef& node);
|
||||||
|
bool IsUnique(const NodeDef& node);
|
||||||
bool IsUnpack(const NodeDef& node);
|
bool IsUnpack(const NodeDef& node);
|
||||||
bool IsVariable(const NodeDef& node);
|
bool IsVariable(const NodeDef& node);
|
||||||
bool IsWhile(const NodeDef& node);
|
bool IsWhile(const NodeDef& node);
|
||||||
|
@ -3466,6 +3466,108 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Eliminates unnecessary copies during sparse embedding lookup operations.
|
||||||
|
//
|
||||||
|
// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
|
||||||
|
// generates code of the form:
|
||||||
|
//
|
||||||
|
// embeddings = <a 2D Tensor>
|
||||||
|
// sparse_ids = <a tf.int64 SparseTensor>
|
||||||
|
// segment_ids = sparse_ids.indices[:, 0]
|
||||||
|
// ids, idx = tf.unique(sparse_ids.values)
|
||||||
|
// gathered_rows = tf.gather(params, ids)
|
||||||
|
// result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
|
||||||
|
//
|
||||||
|
// In this case, all of the work in `tf.unique()` and `tf.gather()`
|
||||||
|
// can be avoided by passing the full embeddings to
|
||||||
|
// `tf.sparse.segment_<combiner>()` and performing the same amount of
|
||||||
|
// computation (but fewer copies and allocations) as follows:
|
||||||
|
//
|
||||||
|
// embeddings = <a 2D Tensor>
|
||||||
|
// sparse_ids = <a tf.int64 SparseTensor>
|
||||||
|
// segment_ids = sparse_ids.indices[:, 0]
|
||||||
|
// result = tf.sparse.segment_<combiner>(
|
||||||
|
// embeddings, sparse_ids.values, segment_ids)
|
||||||
|
class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
|
||||||
|
public:
|
||||||
|
explicit SimplifyEmbeddingLookupStage(
|
||||||
|
const GraphOptimizerContext& ctx,
|
||||||
|
const ArithmeticOptimizerContext& ctx_ext)
|
||||||
|
: ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
|
||||||
|
}
|
||||||
|
~SimplifyEmbeddingLookupStage() override = default;
|
||||||
|
|
||||||
|
bool IsSupported(const NodeDef* node) const override {
|
||||||
|
return IsAnySparseSegmentReduction(*node);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TrySimplify(NodeDef* reduction_node,
|
||||||
|
string* simplified_node_name) override {
|
||||||
|
if (IsInPreserveSet(*reduction_node)) return Status::OK();
|
||||||
|
|
||||||
|
// Input 0 (data) of the reduction node must be a tf.gather() on the 0th
|
||||||
|
// axis.
|
||||||
|
NodeDef* gather_node = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
|
||||||
|
if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
|
||||||
|
gather_node->device() != reduction_node->device())
|
||||||
|
return Status::OK();
|
||||||
|
if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
// Input 1 (indices) of the gather node must be a tf.unique() on the 0th
|
||||||
|
// axis.
|
||||||
|
NodeDef* unique_node = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
|
||||||
|
if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
|
||||||
|
unique_node->device() != gather_node->device())
|
||||||
|
return Status::OK();
|
||||||
|
if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
DataType unique_element_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));
|
||||||
|
|
||||||
|
// Input 1 (indices) of the reduction node must be output 1 of the unique
|
||||||
|
// node.
|
||||||
|
const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
|
||||||
|
if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();
|
||||||
|
|
||||||
|
// Input 0 (data) of the reduction node becomes input 1 (params) of the
|
||||||
|
// gather node.
|
||||||
|
reduction_node->set_input(0, gather_node->input(0));
|
||||||
|
ctx().node_map->UpdateInput(reduction_node->name(),
|
||||||
|
reduction_node->input(0),
|
||||||
|
gather_node->input(0));
|
||||||
|
|
||||||
|
// Input 1 (indices) of the reduction node becomes input 0 (x) of the unique
|
||||||
|
// node.
|
||||||
|
reduction_node->set_input(1, unique_node->input(0));
|
||||||
|
ctx().node_map->UpdateInput(reduction_node->name(),
|
||||||
|
reduction_node->input(1),
|
||||||
|
unique_node->input(0));
|
||||||
|
SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);
|
||||||
|
|
||||||
|
*simplified_node_name = reduction_node->name();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool IsAxis0(const NodeDef& node, int axis_input) {
|
||||||
|
Tensor axis_tensor;
|
||||||
|
if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
|
||||||
|
return false;
|
||||||
|
if (axis_tensor.NumElements() != 1) return false;
|
||||||
|
if (axis_tensor.dtype() == DT_INT32) {
|
||||||
|
return axis_tensor.flat<int32>()(0) == 0;
|
||||||
|
} else if (axis_tensor.dtype() == DT_INT64) {
|
||||||
|
return axis_tensor.flat<int64>()(0) == 0;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||||
@ -3538,6 +3640,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
|||||||
pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
|
pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
|
||||||
if (options_.fuse_squared_diff)
|
if (options_.fuse_squared_diff)
|
||||||
pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
|
pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
|
||||||
|
if (options_.simplify_embedding_lookup)
|
||||||
|
pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
|
||||||
|
|
||||||
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
|
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
|
||||||
<< absl::StrJoin(pipeline.StageNames(), ", ");
|
<< absl::StrJoin(pipeline.StageNames(), ", ");
|
||||||
|
@ -85,6 +85,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
|
|||||||
bool convert_expm1 = true;
|
bool convert_expm1 = true;
|
||||||
bool unary_ops_composition = true;
|
bool unary_ops_composition = true;
|
||||||
bool remove_stack_slice_same_axis = true;
|
bool remove_stack_slice_same_axis = true;
|
||||||
|
bool simplify_embedding_lookup = true;
|
||||||
|
|
||||||
// Choose which arithmetic optimizer stages will be enabled for a given
|
// Choose which arithmetic optimizer stages will be enabled for a given
|
||||||
// optimization level by default.
|
// optimization level by default.
|
||||||
|
@ -4085,5 +4085,49 @@ TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
|
|||||||
test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]);
|
test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
|
||||||
|
for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output embeddings = ops::Const(s.WithOpName("embeddings"),
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||||
|
Output segment_ids =
|
||||||
|
ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
|
||||||
|
Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
|
||||||
|
auto unique = ops::Unique(s.WithOpName("unique"), indices,
|
||||||
|
/*attrs=*/{unique_idx_type});
|
||||||
|
Output ids = unique.y;
|
||||||
|
Output idx = unique.idx;
|
||||||
|
Output gathered_rows =
|
||||||
|
ops::Gather(s.WithOpName("gathered_rows"), embeddings, ids);
|
||||||
|
Output result = ops::SparseSegmentSum(s.WithOpName("result"), gathered_rows,
|
||||||
|
idx, segment_ids);
|
||||||
|
Output id = ops::Identity(s.WithOpName("id"), result);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
item.fetch = {"id"};
|
||||||
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||||
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef output;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlySimplifyEmbeddingLookup(&optimizer);
|
||||||
|
OptimizeAndPrune(&optimizer, &item, &output);
|
||||||
|
|
||||||
|
for (const auto& node : output.node()) {
|
||||||
|
if (node.name() == "result") {
|
||||||
|
EXPECT_EQ(node.input(0), "embeddings");
|
||||||
|
EXPECT_EQ(node.input(1), "indices");
|
||||||
|
}
|
||||||
|
EXPECT_NE(node.op(), "Unique");
|
||||||
|
EXPECT_NE(node.op(), "Gather");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
|
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -223,6 +223,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
|||||||
optimizer->options_.remove_stack_slice_same_axis = true;
|
optimizer->options_.remove_stack_slice_same_axis = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) {
|
||||||
|
DisableAllStages(optimizer);
|
||||||
|
optimizer->options_.simplify_embedding_lookup = true;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void DisableAllStages(ArithmeticOptimizer* optimizer) {
|
void DisableAllStages(ArithmeticOptimizer* optimizer) {
|
||||||
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
|
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
|
||||||
@ -250,6 +255,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
|||||||
options.replace_mul_with_square = false;
|
options.replace_mul_with_square = false;
|
||||||
options.simplify_aggregation = false;
|
options.simplify_aggregation = false;
|
||||||
options.unary_ops_composition = false;
|
options.unary_ops_composition = false;
|
||||||
|
options.simplify_embedding_lookup = false;
|
||||||
optimizer->options_ = options;
|
optimizer->options_ = options;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -4416,7 +4416,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
with self.assertRaisesRegexp(errors.OpError,
|
with self.assertRaisesRegexp(errors.OpError,
|
||||||
r'indices\[0\] = 2 is not in \[0, 2\)'):
|
r'indices\[0\] .* 2 .* \[0, 2\)'):
|
||||||
self.evaluate(embedding_lookup)
|
self.evaluate(embedding_lookup)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user