[TF:XLA] Handle bitcasts between different bitwidths

PiperOrigin-RevId: 251810931
Authored by A. Unique TensorFlower on 2019-06-06 02:09:32 -07:00; committed by TensorFlower Gardener
parent e3b76130b4
commit 3ed65f64e6
2 changed files with 83 additions and 19 deletions
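At the op level, the change lets tf.bitcast between types of different bit widths compile under XLA, provided one bit width divides the other; narrowing appends an inner dimension of size src_bits / dst_bits, and widening folds a matching inner dimension away. Previously the XLA kernel rejected such casts with "Only bitcasts between equally sized types supported." A small sketch of that shape contract (assuming TF 2.x eager execution; not taken from the commit itself):

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(6, dtype=np.int32).reshape(2, 3))

# Narrowing: each 32-bit element splits into 32 / 8 = 4 int8 values, so the
# shape gains an inner dimension: [2, 3] -> [2, 3, 4].
narrow = tf.bitcast(x, tf.int8)

# Widening: an inner dimension of 4 int8 values folds back into one int32,
# so the shape loses that dimension: [2, 3, 4] -> [2, 3].
wide = tf.bitcast(narrow, tf.int32)

print(narrow.shape, wide.shape)   # (2, 3, 4) (2, 3)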


@@ -12,16 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 namespace {
@@ -112,23 +117,85 @@ class BitcastOp : public XlaOpKernel {
     if (src_dtype_ == dst_dtype_) {
       output = input;
-    } else {
+      ctx->SetOutput(0, output);
+      return;
+    }
     // Error out if the bitcast has a complex source or destination type and
     // the bitcast is not trivial.
     OP_REQUIRES(ctx,
                 !xla::primitive_util::IsComplexType(src_type_) &&
                     !xla::primitive_util::IsComplexType(dst_type_),
                 errors::Unimplemented("Complex types not supported."));
-    // XLA bitcast requires that the bit-width of the source and destination
-    // matches, and currently only the simple lowering is performed.
+    auto input_bit_width = xla::primitive_util::BitWidth(src_type_);
+    auto output_bit_width = xla::primitive_util::BitWidth(dst_type_);
+    auto input_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(input_bit_width);
+    auto output_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(output_bit_width);
     OP_REQUIRES(ctx,
-                xla::primitive_util::BitWidth(src_type_) ==
-                    xla::primitive_util::BitWidth(dst_type_),
-                errors::Unimplemented(
-                    "Only bitcasts between equally sized types supported."));
-    output = xla::BitcastConvertType(input, dst_type_);
+                output_bit_width % input_bit_width == 0 ||
+                    input_bit_width % output_bit_width == 0,
+                errors::InvalidArgument(
+                    "Neither bit width is a multiple of the other."));
+    // Modify the input as needed so we only need to bitcast to create the
+    // output.
+    if (input_bit_width > output_bit_width) {
+      // Casting to a smaller bit width results in a new inner dimension.
+      auto broadcasted_input_shape = ctx->InputShape(0);
+      auto reshaped_input_shape = ctx->InputShape(0);
+      broadcasted_input_shape.AddDim(input_bit_width / output_bit_width);
+      reshaped_input_shape.AddDim(1);
+      auto output_bit_width_mask = (1 << output_bit_width) - 1;
+      auto status_or_input =
+          BroadcastTo(xla::Reshape(input, reshaped_input_shape.dim_sizes()),
+                      broadcasted_input_shape.dim_sizes());
+      OP_REQUIRES_OK(ctx, status_or_input.status());
+      input = xla::BitcastConvertType(status_or_input.ConsumeValueOrDie(),
+                                      input_logical_type);
+      auto xla_input_shape_status = ctx->builder()->GetShape(input);
+      OP_REQUIRES_OK(ctx, xla_input_shape_status.status());
+      auto xla_input_shape = xla_input_shape_status.ConsumeValueOrDie();
+      auto iota = xla::Iota(ctx->builder(), xla_input_shape,
+                            xla_input_shape.dimensions_size() - 1);
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, output_bit_width), iota);
+      input = xla::And(xla::ShiftRightLogical(input, iota_m),
+                       xla::ScalarLike(input, output_bit_width_mask));
+      input = xla::ConvertElementType(input, output_logical_type);
+    } else if (input_bit_width < output_bit_width) {
+      // Casting to a larger bit width results in removing the innermost
+      // dimension.
+      auto input_shape = ctx->InputShape(0);
+      xla::Shape input_xla_shape =
+          TensorShapeToXLAShape(dst_type_, input_shape);
+      OP_REQUIRES(
+          ctx,
+          input_shape.dim_size(input_shape.dims() - 1) ==
+              output_bit_width / input_bit_width,
+          errors::InvalidArgument(
+              "Inner dimension of operand should be removed after cast."));
+      auto zero = XlaHelpers::Zero(ctx->builder(), dst_dtype_);
+      input = xla::ConvertElementType(input, dst_type_);
+      // Shift bits and OR them together to reduce the inner dimension.
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, input_bit_width),
+                   xla::Iota(ctx->builder(), input_xla_shape,
+                             input_xla_shape.dimensions_size() - 1));
+      input = xla::ShiftLeft(input, iota_m);
+      input = xla::Reduce(input, zero,
+                          CreateScalarOrComputation(dst_type_, ctx->builder()),
+                          {input_xla_shape.dimensions_size() - 1});
     }
+    output = xla::BitcastConvertType(input, dst_type_);
     ctx->SetOutput(0, output);
   }
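The narrowing branch above broadcasts each element across a new inner dimension, bitcasts to the unsigned integral type of the same width, shifts each copy right by iota * output_bit_width, and masks off the low bits. Roughly what that computes, sketched in NumPy for int32 -> int8 (the helper name and sample values are illustrative; the comparison at the end assumes a little-endian host):

import numpy as np

# Rough NumPy model of the narrowing lowering; not the kernel itself.
def narrow_int32_to_int8(x):
  # BitcastConvertType to the unsigned "logical" type of the same width.
  words = np.asarray(x, dtype=np.int32).view(np.uint32)
  ratio = 32 // 8                    # the new inner dimension has 4 entries
  mask = (1 << 8) - 1                # keep the low output_bit_width bits
  # Broadcast each word across the new inner dimension (cf. BroadcastTo).
  pieces = np.broadcast_to(words[..., None], words.shape + (ratio,))
  # Shift right by iota * 8, then mask (cf. ShiftRightLogical + And).
  shifts = np.arange(ratio, dtype=np.uint32) * 8
  return ((pieces >> shifts) & mask).astype(np.uint8).view(np.int8)

x = np.array([[1, -2], [513, 0]], dtype=np.int32)
print(narrow_int32_to_int8(x).shape)   # (2, 2, 4)
# On a little-endian host this equals x.view(np.int8).reshape(2, 2, 4).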

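The widening branch goes the other way: it checks that the innermost dimension equals output_bit_width / input_bit_width, widens each piece, shifts it left by iota * input_bit_width, and ORs the shifted pieces together with a Reduce over that dimension. The same idea in NumPy for int8 -> int32 (illustrative; unsigned intermediates are used here so sign-extension does not pollute the high bits):

import numpy as np

# Rough NumPy model of the widening lowering; not the kernel itself.
def widen_int8_to_int32(x):
  pieces = np.asarray(x, dtype=np.int8).view(np.uint8).astype(np.uint32)
  ratio = 32 // 8
  # The innermost dimension is the one folded away by the cast.
  assert pieces.shape[-1] == ratio
  # Shift left by iota * 8, then OR-reduce over the inner dimension
  # (cf. ShiftLeft + Reduce with an Or computation).
  shifts = np.arange(ratio, dtype=np.uint32) * 8
  return np.bitwise_or.reduce(pieces << shifts, axis=-1).view(np.int32)

x = np.arange(16, dtype=np.int8).reshape(4, 4)   # the testLarger input
print(widen_int8_to_int32(x).shape)   # (4,)
# On a little-endian host this equals x.view(np.int32).reshape(4).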

@@ -38,14 +38,12 @@ class BitcastTest(test.TestCase):
       self.assertEqual(tf_ans.get_shape(), shape)
       self.assertEqual(tf_ans.dtype, datatype)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testSmaller(self):
     x = np.random.rand(3, 2)
     datatype = dtypes.int8
     shape = [3, 2, 8]
     self._testBitcast(x, datatype, shape)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testLarger(self):
     x = np.arange(16, dtype=np.int8).reshape([4, 4])
     datatype = dtypes.int32
@@ -69,7 +67,6 @@ class BitcastTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Cannot bitcast due to shape"):
       array_ops.bitcast(x, datatype, None)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testEmpty(self):
     x = np.ones([], np.int32)
     datatype = dtypes.int8