[`numpy.asscalar()` has been deprecated since NumPy 1.16](https://github.com/numpy/numpy/blob/master/numpy/lib/type_check.py#L519-L548). This PR replaces its usage with [`numpy.ndarray.item()`](https://www.numpy.org/devdocs/reference/generated/numpy.ndarray.item.html).
96 lines
3.8 KiB
Python
96 lines
3.8 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ======================================
|
|
"""XLA LiteralProto utilities."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as _np # Avoids becoming a part of public Tensorflow API.
|
|
|
|
from tensorflow.compiler.xla import xla_data_pb2
|
|
from tensorflow.compiler.xla.python_api import types
|
|
from tensorflow.compiler.xla.python_api import xla_shape
|
|
|
|
|
|
def ConvertLiteralToNumpyArray(literal):
  """Converts a XLA literal to a Numpy array.

  Args:
    literal: A LiteralProto-like message. Its `shape` supplies the element
      type, dimensions and layout; its content is read from the repeated
      field named by the matching record in `types.MAP_XLA_TYPE_TO_RECORD`.

  Returns:
    A Numpy array with the literal's content. When the literal's element type
    is `xla_data_pb2.TUPLE`, a tuple with one converted value per entry of
    `literal.tuple_literals` (converted recursively).

  Raises:
    NotImplementedError: If the layout's `minor_to_major` is neither
      `[0, 1, ..., rank - 1]` nor `[rank - 1, ..., 1, 0]`.
  """
  element_type = literal.shape.element_type
  if element_type == xla_data_pb2.TUPLE:
    return tuple(
        ConvertLiteralToNumpyArray(subliteral)
        for subliteral in literal.tuple_literals)

  type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type]
  values = getattr(literal, type_record.literal_field_name)
  if not literal.shape.dimensions:
    # No dimensions: the literal is a scalar stored as the field's only entry.
    return _np.array(values[0], type_record.numpy_dtype)

  # Infer the proper Numpy order from the LiteralProto's layout. The repeated
  # field representing the array's content in the Literal is linearized.
  # Reading is done in two steps:
  #
  # 1. Read the array as 1D from the LiteralProto repeated field.
  # 2. Reshape the array to its proper shape, using the right order depending
  #    on the LiteralProto's layout.
  minor_to_major = literal.shape.layout.minor_to_major
  numpy_shape = tuple(literal.shape.dimensions)
  rank = len(numpy_shape)

  # Compare plain lists: on Python 3 a `range` object never compares equal to
  # a list or to a protobuf repeated field, so comparing against `range(...)`
  # directly would reject every layout.
  layout_order = list(minor_to_major)
  if layout_order == list(range(rank)):
    # Dimension 0 is the most minor one: column-major (Fortran) order.
    numpy_order = 'F'
  elif layout_order == list(range(rank - 1, -1, -1)):
    # Dimension 0 is the most major one: row-major (C) order.
    numpy_order = 'C'
  else:
    raise NotImplementedError('Unsupported layout: {0}'.format(minor_to_major))

  # `asarray` rather than `array(..., copy=False)`: NumPy >= 2.0 raises
  # ValueError for copy=False whenever a copy cannot be avoided, which is
  # always the case when reading from a protobuf repeated field.
  flat_array = _np.asarray(values, dtype=type_record.numpy_dtype)
  return flat_array.reshape(numpy_shape, order=numpy_order)
|
|
|
|
|
|
def _ConvertNumpyArrayToLiteral(ndarray):
  """Converts a Numpy array to a XLA literal.

  Args:
    ndarray: A Numpy array. `str(ndarray.dtype)` must be a key of
      `types.MAP_DTYPE_TO_RECORD`.

  Returns:
    An `xla_data_pb2.LiteralProto` whose shape is copied from
    `xla_shape.CreateShapeFromNumpy(ndarray)` and whose content field (named
    by the dtype's type record) holds the array's elements, linearized.
  """
  type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)]
  literal = xla_data_pb2.LiteralProto()
  literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
  literal_field = getattr(literal, type_record.literal_field_name)

  if ndarray.ndim == 0:
    # Scalars are stored as a single entry, converted to a native Python value.
    literal_field.append(
        ndarray.astype(type_record.literal_field_type).item())
    return literal

  # Linearize once, with the same order for every dtype: Fortran order when
  # the array is Fortran-contiguous, C order otherwise. Iterating with
  # `nditer`'s default memory order instead could yield a different element
  # order for non-contiguous arrays.
  flat_values = ndarray.ravel(order='A')
  if ndarray.dtype == _np.bool_:
    # Ndarrays with boolean dtypes need special type conversion with protobufs
    literal_field.extend(
        [type_record.literal_field_type(element) for element in flat_values])
  else:
    literal_field.extend(flat_values)
  return literal
|
|
|
|
|
|
def ConvertNumpyArrayToLiteral(value):
  """Builds an XLA literal from a Numpy array or a (nested) tuple of them.

  Args:
    value: Either a single Numpy array, or a tuple whose components are
      themselves valid inputs to this function.

  Returns:
    An `xla_data_pb2.LiteralProto`. For a tuple input, the literal's shape is
    copied from `xla_shape.CreateShapeFromNumpy(value)` and each component is
    converted recursively into one entry of `tuple_literals`.
  """
  # Anything that is not a tuple is handed to the single-array converter.
  if not isinstance(value, tuple):
    return _ConvertNumpyArrayToLiteral(value)

  tuple_literal = xla_data_pb2.LiteralProto()
  tuple_shape = xla_shape.CreateShapeFromNumpy(value)
  tuple_literal.shape.CopyFrom(tuple_shape.message)
  for element in value:
    # Add an empty sub-literal, then fill it with the converted component.
    slot = tuple_literal.tuple_literals.add()
    slot.CopyFrom(ConvertNumpyArrayToLiteral(element))
  return tuple_literal
|