Add legalizations for RngReadAndSkip and StatelessRandomUniformFullIntV2.

PiperOrigin-RevId: 357766922
Change-Id: I6f3806d714cbeb62e5deafe580c9641b1f0a9e60
Richard Uhler 2021-02-16 11:34:14 -08:00 committed by TensorFlower Gardener
parent e2cd9cda80
commit 3af7f6d31f
7 changed files with 240 additions and 0 deletions

@@ -698,6 +698,7 @@ cc_library(
":decompose_resource_ops_inc_gen",
":tensorflow",
":tensorflow_types",
"//tensorflow/core:framework",
"@llvm-project//mlir:IR",
],
)

@@ -12646,6 +12646,27 @@ def TF_RiscDotOp : TF_Op<"RiscDot", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RngReadAndSkipOp : TF_Op<"RngReadAndSkip", []> {
let summary = "Advance the counter of a counter-based RNG.";
let description = [{
The state of the RNG after
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation choice.
}];
let arguments = (ins
TF_ResourceTensor:$resource,
TF_Int32Tensor:$alg,
TF_Uint64Tensor:$delta
);
let results = (outs
TF_Int64Tensor:$value
);
}
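As a minimal sketch of the contract stated in the summary above, assuming the public tf.random.Generator API (Generator.skip is the user-facing entry point for this op): skipping n draws should leave the generator in the same state as actually drawing n values.

import tensorflow as tf

# Two generators seeded identically with the Philox algorithm.
g1 = tf.random.Generator.from_seed(1234, alg="philox")
g2 = tf.random.Generator.from_seed(1234, alg="philox")

# g1 actually draws n values; g2 only advances its counter.
n = 10
_ = g1.uniform(shape=[n])
g2.skip(n)

# Per the op description, both RNG states should now be identical.
assert (g1.state.numpy() == g2.state.numpy()).all()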
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
let summary = "Rolls the elements of a tensor along an axis.";
@@ -15011,6 +15032,32 @@ The outputs are a deterministic function of `shape` and `seed`.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformFullIntV2Op : TF_Op<"StatelessRandomUniformFullIntV2", [NoSideEffect]> {
let summary = [{
Outputs deterministic pseudorandom integers from a uniform distribution.
}];
let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
}];
let arguments = (ins
Arg<TF_I32OrI64Tensor, [{The shape of the output tensor.}]>:$shape,
Arg<TF_Uint64Tensor, [{Key for the counter-based RNG algorithm (shape uint64[1]).}]>:$key,
Arg<TF_Uint64Tensor, [{Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.}]>:$counter,
Arg<TF_Int32Tensor, [{The RNG algorithm (shape int32[]).}]>:$alg
);
let results = (outs
Res<TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>, [{Random values with specified shape.}]>:$output
);
TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
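A minimal sketch of this determinism contract, assuming the tf.raw_ops bindings for these ops and alg=1 for Philox (the same algorithm value the decomposition test below uses): the output is a pure function of (shape, key, counter, alg), so repeated calls with the same inputs agree exactly.

import tensorflow as tf

# Derive a key/counter pair from a [2]-shaped seed.
key, counter = tf.raw_ops.StatelessRandomGetKeyCounter(seed=[7, 42])
alg = tf.constant(1, dtype=tf.int32)  # Philox

# Full-range uint64 values, deterministic in (shape, key, counter, alg).
x = tf.raw_ops.StatelessRandomUniformFullIntV2(
    shape=[315, 49], key=key, counter=counter, alg=alg, dtype=tf.uint64)
y = tf.raw_ops.StatelessRandomUniformFullIntV2(
    shape=[315, 49], key=key, counter=counter, alg=alg, dtype=tf.uint64)
assert (x.numpy() == y.numpy()).all()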
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom integers from a uniform distribution.

@@ -568,3 +568,17 @@ func @decompose_resource_apply_proximal_adagrad_op(%lr: tensor<f32>, %l1: tensor
return
}
// -----
// Test that tf.RngReadAndSkip op is decomposed.
// CHECK-LABEL: func @decompose_rng_read_and_skip_op
func @decompose_rng_read_and_skip_op(%resource: tensor<!tf.resource<tensor<3xi64>>>) -> tensor<3xi64> {
// We rely on the TensorFlow StatefulRandomOpsTest to check it is lowered
// correctly.
// CHECK-NOT: tf.RngReadAndSkip
%alg = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%delta = "tf.Const"() {value = dense<10> : tensor<ui64>} : () -> tensor<ui64>
%0 = "tf.RngReadAndSkip"(%resource, %alg, %delta) : (tensor<!tf.resource<tensor<3xi64>>>, tensor<i32>, tensor<ui64>) -> tensor<3xi64>
return %0 : tensor<3xi64>
}

@@ -18,6 +18,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/rng_alg.h"
namespace mlir {
namespace TF {
@@ -68,11 +69,152 @@ static Type GetResourceSubtype(Value resource) {
.front();
}
// Decompose tf.RngReadAndSkip.
//
// For Philox, the resource variable holds a tensor<3xi64> with the state:
// [counter_lo, counter_hi, key]
//
// RngReadAndSkip increments the 128-bit counter value by 256 * delta and
// returns the original state value.
//
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
// [counter, key]
//
// RngReadAndSkip increments the 64-bit counter value by 256 * delta and
// returns a tensor<3xi64> value [counter, key, 0].
class DecomposeRngReadAndSkipOp : public RewritePattern {
public:
explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
: RewritePattern(RngReadAndSkipOp::getOperationName(),
{
AddV2Op::getOperationName(),
AssignVariableOp::getOperationName(),
CastOp::getOperationName(),
ConstOp::getOperationName(),
LessOp::getOperationName(),
MulOp::getOperationName(),
PadOp::getOperationName(),
PackOp::getOperationName(),
ReadVariableOp::getOperationName(),
SelectV2Op::getOperationName(),
UnpackOp::getOperationName(),
},
1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto rng_op = cast<RngReadAndSkipOp>(op);
DenseIntElementsAttr alg_constant;
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
op->emitOpError() << "unable to determine algorithm statically";
return failure();
}
if (alg_constant.getNumElements() != 1) {
op->emitOpError() << "expected alg to be a scalar";
return failure();
}
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
tensorflow::Algorithm alg;
if (tensorflow::RNG_ALG_PHILOX == alg_value) {
alg = tensorflow::RNG_ALG_PHILOX;
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
alg = tensorflow::RNG_ALG_THREEFRY;
} else {
op->emitOpError() << "unsupported alg";
return failure();
}
Type state_element_type = rewriter.getI64Type();
RankedTensorType op_type = RankedTensorType::get(
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
state_element_type);
if (op_type != rng_op.getType()) {
op->emitOpError() << "unexpected op type";
return failure();
}
if (!HasResourceSubtype(rng_op.resource())) {
op->emitOpError() << "missing resource subtype";
return failure();
}
int counter_size = tensorflow::GetCounterSize(alg);
int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
RankedTensorType res_type =
RankedTensorType::get({state_size}, state_element_type);
if (res_type != GetResourceSubtype(rng_op.resource())) {
op->emitOpError() << "unexpected resource subtype";
return failure();
}
Location loc = op->getLoc();
// Read the state value from the resource.
Value state =
rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());
// Extract the key and counter from the state.
RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
auto unpacked = rewriter.create<UnpackOp>(
loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
Value key = unpacked.getResult(counter_size);
SmallVector<Value, 4> counter;
for (int i = 0; i < counter_size; ++i) {
counter.push_back(unpacked.getResult(i));
}
// Set the increment to 256 * delta.
Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
Value increment =
rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());
// Increment the counter.
SmallVector<Value, 4> pack_args;
RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
for (int i = 0; i < counter_size; ++i) {
Value word = counter[i];
Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
pack_args.push_back(new_word);
Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
}
// Save the new state value to the resource.
pack_args.push_back(key);
Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);
// Pad the original state as necessary to fill the output shape.
int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
Type i64 = rewriter.getI64Type();
RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
std::vector<int64_t> paddings_values = {0, pad};
Value paddings = rewriter.create<ConstOp>(
loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);
rewriter.replaceOp(op, output);
return success();
}
};
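The carry propagation emitted by the loop above can be summarized by this small Python model (skip_counter is a hypothetical name; it assumes uint64 limbs stored low word first, with wraparound on overflow):

MASK64 = (1 << 64) - 1

def skip_counter(counter_words, delta):
  increment = (256 * delta) & MASK64
  words = list(counter_words)
  for i in range(len(words)):
    new_word = (words[i] + increment) & MASK64
    # As in the LessOp/SelectV2Op pair above: the add overflowed iff the
    # wrapped sum is smaller than the old word, and the carry into the
    # next limb is then 1, otherwise 0.
    increment = 1 if new_word < words[i] else 0
    words[i] = new_word
  return words

# Philox: a 128-bit counter in two limbs. An increment near the 64-bit
# boundary carries into the high limb.
assert skip_counter([MASK64 - 10, 0], delta=1) == [245, 1]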
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
} // namespace
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<DecomposeRngReadAndSkipOp>(context);
populateWithGenerated(context, *patterns);
}

@@ -241,6 +241,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::StatelessRandomNormalV2Op>(),
TypeID::get<TF::StatelessRandomUniformOp>(),
TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
TypeID::get<TF::StatelessRandomUniformV2Op>(),
TypeID::get<TF::StatelessRandomUniformIntOp>(),
TypeID::get<TF::StatelessRandomUniformIntV2Op>(),

@@ -1277,6 +1277,7 @@ tf_xla_py_test(
name = "stateful_random_ops_test",
size = "medium",
srcs = ["stateful_random_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [

@@ -31,6 +31,7 @@ from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as random_test_util
from tensorflow.python.ops import gen_stateful_random_ops
@@ -76,6 +77,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
gen.uniform_full_int(shape=(3,))
@parameterized.parameters(ALGS)
@test_util.disable_mlir_bridge("TODO(b/180412889): Crashes with MLIR bridge.")
def testDefun(self, alg):
"""Test for defun."""
with ops.device(xla_device_name()):
@@ -248,6 +250,36 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
shape=shape, dtype=dtype))
self.assertAllEqual(cpu, xla)
def testXLAEqualsCPUAroundCounterOverflow(self):
"""Tests XLA and CPU kernels generate the same integers in overflow case.
Specifically this tests the case where the counter is incremented past
what can fit within 64 bits of the 128 bit Philox counter.
"""
dtype = dtypes.uint64
seed = 2**64 - 10
shape = [315, 49]
if compat.forward_compatible(2020, 10, 25):
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the state after
# number generation is the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)
self.assertAllEqual(cpu, xla)
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
# The random-number generator, if working correctly, should produce the
@@ -352,6 +384,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
mean_atol=2e-3, median_atol=4e-3,
variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)
@test_util.disable_mlir_bridge(
"b/180412086: MLIR bridge gives wrong error messages.")
def testErrors(self):
"""Tests that proper errors are raised.
"""