AddN for variants adds in a tree structure (pairwise summation)
Improves numerical precision (if applicable) using pairwise summation:
https://en.wikipedia.org/wiki/Pairwise_summation

Thanks to Rasmus Larsen for the succinct binary tree aggregation pseudocode.

Also adds AsString for Variant types: this emits the Variant as a string via its DebugString().

PiperOrigin-RevId: 340246073
Change-Id: I009281f46cbea30d6e33ecf79a1723d62e96cc6d
parent faac5b2fc4
commit a77e4aec3f

Changed paths:
  tensorflow/compiler/mlir/tensorflow/ir
  tensorflow/core
  tensorflow/python/kernel_tests
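
As a standalone illustration of the precision claim above (this sketch is not part of the commit; it sums plain fp32 values rather than Variants, and the names NaiveSum/PairwiseSum are hypothetical): left-to-right accumulation has round-off error that grows roughly linearly in the number of terms, while pairwise summation's error grows only logarithmically.

    #include <cstdio>
    #include <vector>

    // Naive left-to-right accumulation: round-off error grows O(n).
    float NaiveSum(const std::vector<float>& x) {
      float s = 0.0f;
      for (float v : x) s += v;
      return s;
    }

    // Pairwise (tree) summation: round-off error grows O(log n).
    float PairwiseSum(const float* x, int n) {
      if (n <= 2) return n == 2 ? x[0] + x[1] : x[0];
      const int half = n / 2;
      return PairwiseSum(x, half) + PairwiseSum(x + half, n - half);
    }

    int main() {
      // 2^25 copies of 1.0f; the exact sum is 33554432.
      std::vector<float> x(1 << 25, 1.0f);
      // The naive sum stalls at 2^24 = 16777216, because in fp32
      // 16777216.0f + 1.0f rounds back to 16777216.0f.
      std::printf("naive:    %.0f\n", NaiveSum(x));
      std::printf("pairwise: %.0f\n",
                   PairwiseSum(x.data(), static_cast<int>(x.size())));
    }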
@@ -509,7 +509,7 @@ array([b'3.14', b'2.72'], dtype=object)
   }];

   let arguments = (ins
-    TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$input,
+    TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Variant]>:$input,

     DefaultValuedAttr<I64Attr, "-1">:$precision,
     DefaultValuedAttr<BoolAttr, "false">:$scientific,
@@ -15226,4 +15226,4 @@ execution the transfer corresponds to.}]>:$dynamic_key,
   let results = (outs);

   TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
-}
\ No newline at end of file
+}
@@ -370,24 +370,77 @@ class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
               i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
     }

-    // Step 2: attempt to add using
+    // Step 2: Sum input variants in a tree-like structure using
     // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
-    // For the output create a default-constructed variant object.
-    // TODO(ebrevdo): Perform summation in a tree-structure.
-    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
-    Variant* v_out = &(out.scalar<Variant>()());
-    OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(
-                            ctx, ADD_VARIANT_BINARY_OP,
-                            ctx->input(0).template scalar<Variant>()(),
-                            ctx->input(1).template scalar<Variant>()(), v_out));
-    for (int i = 2; i < num; ++i) {
-      const Variant tmp = std::move(*v_out);
-      const Variant& inp = ctx->input(i).template scalar<Variant>()();
-      OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
-                                                   inp, tmp, v_out));
-    }
+    //
+    // Pairwise summation provides better numerical precision by
+    // reducing round-off error:
+    //
+    // https://en.wikipedia.org/wiki/Pairwise_summation
+    //
+    // These two vectors are used to store and mark intermediate sums.
+    gtl::InlinedVector<bool, 4> temp_filled(num, false);
+    gtl::InlinedVector<Variant, 4> temp(num);
+
+    // Tree-based summation.
+    int skip = 1;
+    int n = num;
+    while (skip < n) {
+      int i = skip;
+      while (i < n) {
+        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
+        // inner loop if the variants are "large".

+        // x[i - skip] += x[i]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i].clear();
+        i += 2 * skip;
+      }
+      if (i == n) {
+        // x[0] += x[i - skip]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i - skip].clear();
+        n -= skip;
+      }
+      skip *= 2;
+    }
+
+    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
+    out.scalar<Variant>()() = std::move(temp[0]);
     ctx->set_output(0, out);
   }
+
+ private:
+  // AddVariantTo efficiently performs:
+  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
+  // where array(ix) := (temp_filled[ix]
+  //                     ? temp[ix]
+  //                     : ctx->input(ix).scalar<Variant>()())
+  // This reduces (possibly expensive) copying of Variants from
+  // the inputs into temp at the lowest levels of the summation tree.
+  static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
+                                    const int rhs_ix,
+                                    gtl::InlinedVector<Variant, 4>* temp,
+                                    gtl::InlinedVector<bool, 4>* temp_filled) {
+    Variant tmp;
+    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
+    const Variant& a = temp_filled->at(lhs_ix)
+                           ? tmp
+                           : ctx->input(lhs_ix).template scalar<Variant>()();
+    const Variant& b = temp_filled->at(rhs_ix)
+                           ? temp->at(rhs_ix)
+                           : ctx->input(rhs_ix).template scalar<Variant>()();
+    Variant* c = &temp->at(lhs_ix);
+    TF_RETURN_IF_ERROR(
+        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
+    temp_filled->at(lhs_ix) = true;
+    return Status::OK();
+  }
 };

 } // namespace tensorflow
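
To see how the skip-doubling loop above reduces num = 5 inputs, trace the slot updates (illustrative, not part of the diff):

    skip=1, n=5:  x[0] += x[1]; x[2] += x[3]; inner loop exits with i == 5 == n, so x[0] += x[4] and n becomes 4
    skip=2, n=4:  x[0] += x[2]; inner loop exits with i == 6 != n
    skip=4, n=4:  skip < n fails; the result ((x0 + x1) + x4) + (x2 + x3) is left in slot 0

The same index pattern, stripped down to plain doubles (a minimal sketch with the hypothetical name TreeSum; not the kernel's code):

    #include <cstdio>
    #include <vector>

    // Iterative in-place pairwise reduction mirroring the skip-doubling
    // loop in AddNOp<Variant>::Compute above.
    double TreeSum(std::vector<double> x) {
      int n = static_cast<int>(x.size());
      int skip = 1;
      while (skip < n) {
        int i = skip;
        while (i < n) {
          x[i - skip] += x[i];  // pairwise combine two partial sums
          i += 2 * skip;
        }
        if (i == n) {
          // Fold the leftover partial sum into slot 0 and shrink the range.
          x[0] += x[i - skip];
          n -= skip;
        }
        skip *= 2;
      }
      return x[0];
    }

    int main() {
      std::printf("%g\n", TreeSum({1, 2, 3, 4, 5}));  // prints 15
    }

AddVariantTo layers laziness on top of this pattern: a slot is read from temp only if an intermediate sum was already written there (tracked by temp_filled), and otherwise is read straight from the op's inputs, so leaf-level Variants are never copied into temp first.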
@@ -20,6 +20,9 @@ limitations under the License.
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
@@ -112,6 +115,8 @@ class AsStringOp : public OpKernel {
         break;
       case DT_BOOL:
         break;
+      case DT_VARIANT:
+        break;
       default:
         bool type_not_supported = true;
         OP_REQUIRES(ctx, !type_not_supported,
@@ -156,6 +161,12 @@ class AsStringOp : public OpKernel {
           output_flat(i) = (input_flat(i)) ? "true" : "false";
         }
       } break;
+      case (DT_VARIANT): {
+        const auto& input_flat = input_tensor->flat<Variant>();
+        for (int i = 0; i < input_flat.size(); ++i) {
+          output_flat(i) = input_flat(i).DebugString();
+        }
+      } break;
       case (DT_COMPLEX64): {
         const auto& input_flat = input_tensor->flat<complex64>();
         for (int i = 0; i < input_flat.size(); ++i) {
@@ -18,6 +18,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -148,6 +151,25 @@ TEST_F(AsStringGraphTest, Bool) {
   test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
 }

+TEST_F(AsStringGraphTest, Variant) {
+  TF_ASSERT_OK(Init(DT_VARIANT));
+
+  AddInput(DT_VARIANT, TensorShape({4}));
+  auto inputs = mutable_input(0)->flat<Variant>();
+  inputs(0) = 2;
+  inputs(1) = 3;
+  inputs(2) = true;
+  inputs(3) = Tensor("hi");
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_STRING, TensorShape({4}));
+  test::FillValues<tstring>(
+      &expected, {"Variant<type: int value: 2>", "Variant<type: int value: 3>",
+                  "Variant<type: bool value: 1>",
+                  ("Variant<type: tensorflow::Tensor value: Tensor<type: string"
+                   " shape: [] values: hi>>")});
+  test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
+}
+
 TEST_F(AsStringGraphTest, String) {
   Status s = Init(DT_STRING);
   ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
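
The expected strings in the test above come straight from Variant::DebugString(). A minimal sketch of that behavior (assumes a TensorFlow C++ build environment; not part of the commit):

    #include <iostream>
    #include "tensorflow/core/framework/variant.h"

    int main() {
      tensorflow::Variant v;
      v = 2;  // wrap an int, as the test does via inputs(0) = 2
      // Prints "Variant<type: int value: 2>", matching the test expectation.
      std::cout << v.DebugString() << std::endl;
    }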
@@ -116,7 +116,7 @@ REGISTER_OP("AsString")
     .Output("output: string")
     .Attr(
         "T: {int8, int16, int32, int64, complex64, complex128, float, double, "
-        "bool}")
+        "bool, variant}")
     .Attr("precision: int = -1")
     .Attr("scientific: bool = false")
     .Attr("shortest: bool = false")
@@ -1606,6 +1606,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:string_ops",
        "//third_party/py/numpy",
     ],
 )
@@ -26,8 +26,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test

@@ -100,24 +100,28 @@ class AddNTest(test.TestCase):
     # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
     # copying between CPU and GPU is supported.
     with self.session(use_gpu=False):
-      variant_const_3 = create_constant_variant(3)
-      variant_const_4 = create_constant_variant(4)
-      variant_const_5 = create_constant_variant(5)
-      # 3 + 3 + 5 + 4 = 15.
-      result = math_ops.add_n((variant_const_3, variant_const_3,
-                               variant_const_5, variant_const_4))
+      num_tests = 127
+      values = list(range(100))
+      variant_consts = [create_constant_variant(x) for x in values]
+      sum_count_indices = np.random.randint(1, 29, size=num_tests)
+      sum_indices = [
+          np.random.randint(100, size=count) for count in sum_count_indices]
+      expected_sums = [np.sum(x) for x in sum_indices]
+      variant_sums = [math_ops.add_n([variant_consts[i] for i in x])
+                      for x in sum_indices]

-      # Smoke test -- ensure this executes without trouble.
+      # We use as_string() to get the Variant DebugString for the
+      # variant_sums; we know its value so we can check via string equality
+      # here.
+      #
       # Right now, non-numpy-compatible objects cannot be returned from a
       # session.run call; similarly, objects that can't be converted to
       # native numpy types cannot be passed to ops.convert_to_tensor.
-      # For now, run the test and examine the output to see that the result is
-      # equal to 15.
-      result_op = logging_ops.Print(
-          result, [variant_const_3, variant_const_4, variant_const_5, result],
-          message=("Variants stored an int: c(3), c(4), c(5), "
-                   "add_n(c(3), c(3), c(5), c(4)): ")).op
-      result_op.run()
+      variant_sums_string = string_ops.as_string(variant_sums)
+      self.assertAllEqual(
+          variant_sums_string,
+          ["Variant<type: int value: {}>".format(s).encode("utf-8")
+           for s in expected_sums])


 if __name__ == "__main__":