Add legalizations for RngReadAndSkip and StatelessRandomUniformFullIntV2.
PiperOrigin-RevId: 357766922 Change-Id: I6f3806d714cbeb62e5deafe580c9641b1f0a9e60
This commit is contained in:
parent
e2cd9cda80
commit
3af7f6d31f
tensorflow/compiler
mlir
tensorflow
xla/transforms
tests
@ -698,6 +698,7 @@ cc_library(
|
|||||||
":decompose_resource_ops_inc_gen",
|
":decompose_resource_ops_inc_gen",
|
||||||
":tensorflow",
|
":tensorflow",
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -12646,6 +12646,27 @@ def TF_RiscDotOp : TF_Op<"RiscDot", [NoSideEffect]> {
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ODS definition of tf.RngReadAndSkip. The trait list is empty, so the op is
// not declared NoSideEffect (it reads and updates the RNG state held in
// `resource`).
def TF_RngReadAndSkipOp : TF_Op<"RngReadAndSkip", []> {
  let summary = "Advance the counter of a counter-based RNG.";

  let description = [{
The state of the RNG after
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation choice.
  }];

  // Operand/result descriptions follow the Arg<>/Res<> convention used by the
  // other RNG ops in this file; they only affect generated documentation.
  let arguments = (ins
    Arg<TF_ResourceTensor, [{The handle of the resource variable that stores the state of the RNG.}]>:$resource,
    Arg<TF_Int32Tensor, [{The RNG algorithm.}]>:$alg,
    Arg<TF_Uint64Tensor, [{The amount of advancement.}]>:$delta
  );

  let results = (outs
    Res<TF_Int64Tensor, [{The value of the RNG state before the counter was advanced.}]>:$value
  );
}
|
||||||
|
|
||||||
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
|
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
|
||||||
let summary = "Rolls the elements of a tensor along an axis.";
|
let summary = "Rolls the elements of a tensor along an axis.";
|
||||||
|
|
||||||
@ -15011,6 +15032,32 @@ The outputs are a deterministic function of `shape` and `seed`.
|
|||||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ODS definition of tf.StatelessRandomUniformFullIntV2. The op is declared
// NoSideEffect: its output depends only on its four operands.
// NOTE(review): the neighbouring StatelessRandomUniformInt op additionally
// carries TF_NoConstantFold; confirm that omitting it here is intended.
def TF_StatelessRandomUniformFullIntV2Op : TF_Op<"StatelessRandomUniformFullIntV2", [NoSideEffect]> {
|
||||||
|
let summary = [{
|
||||||
|
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
The generated values are uniform integers covering the whole range of `dtype`.
|
||||||
|
|
||||||
|
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// Operands, in order: output shape, RNG key, initial counter, algorithm id.
let arguments = (ins
|
||||||
|
Arg<TF_I32OrI64Tensor, [{The shape of the output tensor.}]>:$shape,
|
||||||
|
Arg<TF_Uint64Tensor, [{Key for the counter-based RNG algorithm (shape uint64[1]).}]>:$key,
|
||||||
|
Arg<TF_Uint64Tensor, [{Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.}]>:$counter,
|
||||||
|
Arg<TF_Int32Tensor, [{The RNG algorithm (shape int32[]).}]>:$alg
|
||||||
|
);
|
||||||
|
|
||||||
|
// Single result: a tensor of 32- or 64-bit signed or unsigned integers.
let results = (outs
|
||||||
|
Res<TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>, [{Random values with specified shape.}]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
// Derived attributes: Tshape comes from the type of operand 0 ($shape) and
// dtype from the type of result 0 ($output).
TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
|
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||||
|
@ -568,3 +568,17 @@ func @decompose_resource_apply_proximal_adagrad_op(%lr: tensor<f32>, %l1: tensor
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test that tf.RngReadAndSkip op is decomposed.
|
||||||
|
// CHECK-LABEL: func @decompose_rng_read_and_skip_op
|
||||||
|
// Takes a resource handle whose subtype is tensor<3xi64> and returns the
// tensor<3xi64> result of tf.RngReadAndSkip.
func @decompose_rng_read_and_skip_op(%resource: tensor<!tf.resource<tensor<3xi64>>>) -> tensor<3xi64> {
|
||||||
|
// We rely on the TensorFlow StatefulRandomOpsTest to check it is lowered
|
||||||
|
// correctly.
|
||||||
|
// CHECK-NOT: tf.RngReadAndSkip
|
||||||
|
// The algorithm must be a constant so the decomposition can resolve it
// statically.
// NOTE(review): the value 1 presumably selects Philox, which would match the
// 3xi64 resource state; confirm against rng_alg.h.
%alg = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// Amount by which to advance the generator (unsigned 64-bit scalar).
%delta = "tf.Const"() {value = dense<10> : tensor<ui64>} : () -> tensor<ui64>
|
||||||
|
// Operands are (resource, alg, delta). This is the op that must be gone
// after the pass runs.
%0 = "tf.RngReadAndSkip"(%resource, %alg, %delta) : (tensor<!tf.resource<tensor<3xi64>>>, tensor<i32>, tensor<ui64>) -> tensor<3xi64>
|
||||||
|
return %0 : tensor<3xi64>
|
||||||
|
}
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/core/framework/rng_alg.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
@ -68,11 +69,152 @@ static Type GetResourceSubtype(Value resource) {
|
|||||||
.front();
|
.front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decompose tf.RngReadAndSkip.
|
||||||
|
//
|
||||||
|
// For Philox, the resource variable holds a tensor<3xi64> with the state:
|
||||||
|
// [counter_lo, counter_hi, key]
|
||||||
|
//
|
||||||
|
// RngReadAndSkip increments the 128 bit counter value by 256 * delta and
|
||||||
|
// returns the original state value.
|
||||||
|
//
|
||||||
|
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
|
||||||
|
// [counter, key]
|
||||||
|
//
|
||||||
|
// RngReadAndSkip increments the 64 bit counter value by 256 * delta and
|
||||||
|
// returns a tensor<3xi64> value [counter, key, 0].
|
||||||
|
class DecomposeRngReadAndSkipOp : public RewritePattern {
|
||||||
|
public:
|
||||||
|
explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
|
||||||
|
: RewritePattern(RngReadAndSkipOp::getOperationName(),
|
||||||
|
{
|
||||||
|
AddV2Op::getOperationName(),
|
||||||
|
AssignVariableOp::getOperationName(),
|
||||||
|
CastOp::getOperationName(),
|
||||||
|
ConstOp::getOperationName(),
|
||||||
|
LessOp::getOperationName(),
|
||||||
|
MulOp::getOperationName(),
|
||||||
|
PadOp::getOperationName(),
|
||||||
|
PackOp::getOperationName(),
|
||||||
|
ReadVariableOp::getOperationName(),
|
||||||
|
SelectV2Op::getOperationName(),
|
||||||
|
UnpackOp::getOperationName(),
|
||||||
|
},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto rng_op = cast<RngReadAndSkipOp>(op);
|
||||||
|
|
||||||
|
DenseIntElementsAttr alg_constant;
|
||||||
|
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
|
||||||
|
op->emitOpError() << "unable to determine algorithm statically";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (alg_constant.getNumElements() != 1) {
|
||||||
|
op->emitOpError() << "expected alg to be a scalar";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
|
||||||
|
tensorflow::Algorithm alg;
|
||||||
|
if (tensorflow::RNG_ALG_PHILOX == alg_value) {
|
||||||
|
alg = tensorflow::RNG_ALG_PHILOX;
|
||||||
|
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
|
||||||
|
alg = tensorflow::RNG_ALG_THREEFRY;
|
||||||
|
} else {
|
||||||
|
op->emitOpError() << "unsupported alg";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type state_element_type = rewriter.getI64Type();
|
||||||
|
RankedTensorType op_type = RankedTensorType::get(
|
||||||
|
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
|
||||||
|
state_element_type);
|
||||||
|
if (op_type != rng_op.getType()) {
|
||||||
|
op->emitOpError() << "unexpected op type";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!HasResourceSubtype(rng_op.resource())) {
|
||||||
|
op->emitOpError() << "missing resource subtype";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
int counter_size = tensorflow::GetCounterSize(alg);
|
||||||
|
int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
|
||||||
|
RankedTensorType res_type =
|
||||||
|
RankedTensorType::get({state_size}, state_element_type);
|
||||||
|
if (res_type != GetResourceSubtype(rng_op.resource())) {
|
||||||
|
op->emitOpError() << "unexpected resource subtype";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
|
// Read the state value from the resource.
|
||||||
|
Value state =
|
||||||
|
rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());
|
||||||
|
|
||||||
|
// Extract the key and counter from the state.
|
||||||
|
RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
|
||||||
|
auto unpacked = rewriter.create<UnpackOp>(
|
||||||
|
loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
|
||||||
|
Value key = unpacked.getResult(counter_size);
|
||||||
|
|
||||||
|
SmallVector<Value, 4> counter;
|
||||||
|
for (int i = 0; i < counter_size; ++i) {
|
||||||
|
counter.push_back(unpacked.getResult(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the increment to 256 * delta.
|
||||||
|
Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
|
||||||
|
RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
|
||||||
|
Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
|
||||||
|
Value increment =
|
||||||
|
rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());
|
||||||
|
|
||||||
|
// Increment the counter.
|
||||||
|
SmallVector<Value, 4> pack_args;
|
||||||
|
RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
|
||||||
|
Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
|
||||||
|
Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
|
||||||
|
for (int i = 0; i < counter_size; ++i) {
|
||||||
|
Value word = counter[i];
|
||||||
|
Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
|
||||||
|
Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
|
||||||
|
Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
|
||||||
|
pack_args.push_back(new_word);
|
||||||
|
|
||||||
|
Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
|
||||||
|
increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the new state value to the resource.
|
||||||
|
pack_args.push_back(key);
|
||||||
|
Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
|
||||||
|
rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);
|
||||||
|
|
||||||
|
// Pad the original state as necessary to fill the output shape.
|
||||||
|
int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
|
||||||
|
Type i64 = rewriter.getI64Type();
|
||||||
|
RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
|
||||||
|
std::vector<int64_t> paddings_values = {0, pad};
|
||||||
|
Value paddings = rewriter.create<ConstOp>(
|
||||||
|
loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
|
||||||
|
Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, output);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
|
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
|
patterns->insert<DecomposeRngReadAndSkipOp>(context);
|
||||||
populateWithGenerated(context, *patterns);
|
populateWithGenerated(context, *patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,6 +241,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::StatelessRandomNormalV2Op>(),
|
TypeID::get<TF::StatelessRandomNormalV2Op>(),
|
||||||
TypeID::get<TF::StatelessRandomUniformOp>(),
|
TypeID::get<TF::StatelessRandomUniformOp>(),
|
||||||
TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
|
TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
|
||||||
|
TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
|
||||||
TypeID::get<TF::StatelessRandomUniformV2Op>(),
|
TypeID::get<TF::StatelessRandomUniformV2Op>(),
|
||||||
TypeID::get<TF::StatelessRandomUniformIntOp>(),
|
TypeID::get<TF::StatelessRandomUniformIntOp>(),
|
||||||
TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
|
TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
|
||||||
|
@ -1277,6 +1277,7 @@ tf_xla_py_test(
|
|||||||
name = "stateful_random_ops_test",
|
name = "stateful_random_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["stateful_random_ops_test.py"],
|
srcs = ["stateful_random_ops_test.py"],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 10,
|
shard_count = 10,
|
||||||
tags = [
|
tags = [
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import config
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.kernel_tests.random import util as \
|
from tensorflow.python.kernel_tests.random import util as \
|
||||||
random_test_util
|
random_test_util
|
||||||
from tensorflow.python.ops import gen_stateful_random_ops
|
from tensorflow.python.ops import gen_stateful_random_ops
|
||||||
@ -76,6 +77,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
gen.uniform_full_int(shape=(3,))
|
gen.uniform_full_int(shape=(3,))
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(ALGS)
|
||||||
|
@test_util.disable_mlir_bridge("TODO(b/180412889): Crashes with MLIR bridge.")
|
||||||
def testDefun(self, alg):
|
def testDefun(self, alg):
|
||||||
"""Test for defun."""
|
"""Test for defun."""
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
@ -248,6 +250,36 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
shape=shape, dtype=dtype))
|
shape=shape, dtype=dtype))
|
||||||
self.assertAllEqual(cpu, xla)
|
self.assertAllEqual(cpu, xla)
|
||||||
|
|
||||||
|
def testXLAEqualsCPUAroundCounterOverflow(self):
  """Tests XLA and CPU kernels generate the same integers in overflow case.

  Specifically this tests the case where the counter is incremented past
  what can fit within 64 bits of the 128 bit Philox counter.
  """
  dtype = dtypes.uint64
  # NOTE(review): a seed just below 2**64 presumably puts the low counter
  # word close to wrapping; confirm against Generator.from_seed.
  seed = 2**64 - 10
  shape = [315, 49]
  if compat.forward_compatible(2020, 10, 25):
    # Build one Philox generator per device from the same seed.
    with ops.device("/device:CPU:0"):
      cpu_gen = random.Generator.from_seed(
          seed=seed, alg=random.RNG_ALG_PHILOX)
    with ops.device(xla_device_name()):
      xla_gen = random.Generator.from_seed(
          seed=seed, alg=random.RNG_ALG_PHILOX)
    # Repeat multiple times to make sure that the state after
    # number-generation are the same between CPU and XLA.
    for _ in range(5):
      with ops.device("/device:CPU:0"):
        # Test both number-generation and skip
        cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
        cpu_gen.skip(100)
      with ops.device(xla_device_name()):
        xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
        xla_gen.skip(100)
      # Both the generated values and the generator state must agree after
      # every round. The last iteration's check also covers the final values,
      # so no separate assertion is needed after the loop.
      self.assertAllEqual(cpu, xla)
      self.assertAllEqual(cpu_gen.state, xla_gen.state)
|
||||||
|
|
||||||
def _testRngIsNotConstant(self, rng, dtype):
|
def _testRngIsNotConstant(self, rng, dtype):
|
||||||
# Tests that 'rng' does not always return the same value.
|
# Tests that 'rng' does not always return the same value.
|
||||||
# The random-number generator, if working correctly, should produce the
|
# The random-number generator, if working correctly, should produce the
|
||||||
@ -352,6 +384,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
mean_atol=2e-3, median_atol=4e-3,
|
mean_atol=2e-3, median_atol=4e-3,
|
||||||
variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)
|
variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"b/180412086: MLIR bridge gives wrong error messages.")
|
||||||
def testErrors(self):
|
def testErrors(self):
|
||||||
"""Tests that proper errors are raised.
|
"""Tests that proper errors are raised.
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user