Delete some unnecessary code.

PiperOrigin-RevId: 168026197
This commit is contained in:
A. Unique TensorFlower 2017-09-08 11:54:09 -07:00 committed by TensorFlower Gardener
parent c17096fe77
commit affbc9b7b3
5 changed files with 42 additions and 50 deletions

View File

@ -33,7 +33,7 @@ from tensorflow.python.framework import test_util
def truncated_normal(shape): def truncated_normal(shape):
return execute.execute( return execute.execute(
'TruncatedNormal', b'TruncatedNormal',
1, 1,
inputs=[shape], inputs=[shape],
attrs=('dtype', dtypes.float32.as_datatype_enum, 'T', attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
@ -118,7 +118,7 @@ class TFETest(test_util.TensorFlowTestCase):
y = tensor.Tensor(2.) y = tensor.Tensor(2.)
# Add would fail if t2 were not on GPU # Add would fail if t2 were not on GPU
result = execute.execute( result = execute.execute(
'Add', 1, inputs=[x, y], b'Add', 1, inputs=[x, y],
attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy() attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy()
self.assertEqual(3, result) self.assertEqual(3, result)
@ -161,7 +161,7 @@ class TFETest(test_util.TensorFlowTestCase):
three = tensor.Tensor(3) three = tensor.Tensor(3)
five = tensor.Tensor(5) five = tensor.Tensor(5)
product = execute.execute( product = execute.execute(
'Mul', b'Mul',
num_outputs=1, num_outputs=1,
inputs=[three, five], inputs=[three, five],
attrs=('T', three.dtype.as_datatype_enum))[0] attrs=('T', three.dtype.as_datatype_enum))[0]
@ -171,7 +171,7 @@ class TFETest(test_util.TensorFlowTestCase):
# num_outputs provided is 50, but only one output is produced. # num_outputs provided is 50, but only one output is produced.
# That should be okay. # That should be okay.
product = execute.execute( product = execute.execute(
'Mul', b'Mul',
num_outputs=50, num_outputs=50,
inputs=[tensor.Tensor(3), tensor.Tensor(5)], inputs=[tensor.Tensor(3), tensor.Tensor(5)],
attrs=('T', dtypes.int32.as_datatype_enum))[0] attrs=('T', dtypes.int32.as_datatype_enum))[0]
@ -183,7 +183,7 @@ class TFETest(test_util.TensorFlowTestCase):
three = tensor.Tensor([[3.]]).as_gpu_tensor() three = tensor.Tensor([[3.]]).as_gpu_tensor()
five = tensor.Tensor([[5.]]).as_gpu_tensor() five = tensor.Tensor([[5.]]).as_gpu_tensor()
product = execute.execute( product = execute.execute(
'MatMul', b'MatMul',
num_outputs=1, num_outputs=1,
inputs=[three, five], inputs=[three, five],
attrs=('transpose_a', False, 'transpose_b', False, 'T', attrs=('transpose_a', False, 'transpose_b', False, 'T',
@ -192,7 +192,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteStringAttr(self): def testExecuteStringAttr(self):
checked_three = execute.execute( checked_three = execute.execute(
'CheckNumerics', b'CheckNumerics',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.)], inputs=[tensor.Tensor(3.)],
attrs=('message', 'just checking', 'T', attrs=('message', 'just checking', 'T',
@ -202,14 +202,14 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteStringAttrBadValue(self): def testExecuteStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute( _ = execute.execute(
'CheckNumerics', b'CheckNumerics',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.)], inputs=[tensor.Tensor(3.)],
attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum)) attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))
def testExecuteFloatAttr(self): def testExecuteFloatAttr(self):
almost_equal = execute.execute( almost_equal = execute.execute(
'ApproximateEqual', b'ApproximateEqual',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)], inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0] attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
@ -218,14 +218,14 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteFloatAttrBadValue(self): def testExecuteFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute( _ = execute.execute(
'ApproximateEqual', b'ApproximateEqual',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)], inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum)) attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))
def testExecuteIntAttr(self): def testExecuteIntAttr(self):
total = execute.execute( total = execute.execute(
'AddN', b'AddN',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)], inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0] attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
@ -234,7 +234,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteIntAttrBadValue(self): def testExecuteIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute( _ = execute.execute(
'AddN', b'AddN',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)], inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2')) attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))
@ -242,7 +242,7 @@ class TFETest(test_util.TensorFlowTestCase):
# Looks like we don't have an existing op with list(bool) attrs. # Looks like we don't have an existing op with list(bool) attrs.
def testExecuteBoolAttr(self): def testExecuteBoolAttr(self):
product = execute.execute( product = execute.execute(
'MatMul', b'MatMul',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([[3]]), inputs=[tensor.Tensor([[3]]),
tensor.Tensor([[5]])], tensor.Tensor([[5]])],
@ -252,7 +252,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteShapeAttr(self): def testExecuteShapeAttr(self):
execute.execute( execute.execute(
'VarHandleOp', b'VarHandleOp',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum, attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
@ -261,7 +261,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteShapeAttrBadValue(self): def testExecuteShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'VarHandleOp', b'VarHandleOp',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum, attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
@ -269,7 +269,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListStringAttr(self): def testExecuteListStringAttr(self):
execute.execute( execute.execute(
'TensorSummary', b'TensorSummary',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.0)], inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', attrs=('T', dtypes.float32.as_datatype_enum, 'description',
@ -279,7 +279,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListStringAttrBadValue(self): def testExecuteListStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'TensorSummary', b'TensorSummary',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.0)], inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
@ -288,7 +288,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListStringAttrBadListValue(self): def testExecuteListStringAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'TensorSummary', b'TensorSummary',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3.0)], inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
@ -296,7 +296,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListFloatAttr(self): def testExecuteListFloatAttr(self):
b = execute.execute( b = execute.execute(
'Bucketize', b'Bucketize',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])], inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0, attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0,
@ -306,7 +306,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListFloatAttrBadValue(self): def testExecuteListFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Bucketize', b'Bucketize',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])], inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0)) attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))
@ -314,7 +314,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListFloatAttrBadListValue(self): def testExecuteListFloatAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Bucketize', b'Bucketize',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])], inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
@ -322,7 +322,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListIntAttr(self): def testExecuteListIntAttr(self):
b = execute.execute( b = execute.execute(
'Squeeze', b'Squeeze',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])], inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0] attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
@ -331,7 +331,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListIntAttrBadValue(self): def testExecuteListIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Squeeze', b'Squeeze',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])], inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0)) attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
@ -339,7 +339,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListIntAttrBadListValue(self): def testExecuteListIntAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Squeeze', b'Squeeze',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])], inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
@ -347,7 +347,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListTypeListShapeAttr(self): def testExecuteListTypeListShapeAttr(self):
execute.execute( execute.execute(
'Barrier', b'Barrier',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
@ -356,7 +356,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListTypeAttrBadValue(self): def testExecuteListTypeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Barrier', b'Barrier',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes', attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
@ -365,7 +365,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListTypeAttrBadListValue(self): def testExecuteListTypeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Barrier', b'Barrier',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1, attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
@ -374,7 +374,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListShapeAttrBadValue(self): def testExecuteListShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Barrier', b'Barrier',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
@ -383,7 +383,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteListShapeAttrBadListValue(self): def testExecuteListShapeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Barrier', b'Barrier',
num_outputs=1, num_outputs=1,
inputs=[], inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
@ -393,7 +393,7 @@ class TFETest(test_util.TensorFlowTestCase):
split_dim = 1 split_dim = 1
value = [[0, 1, 2], [3, 4, 5]] value = [[0, 1, 2], [3, 4, 5]]
x1, x2, x3 = execute.execute( x1, x2, x3 = execute.execute(
'Split', b'Split',
num_outputs=3, num_outputs=3,
inputs=[tensor.Tensor(split_dim), inputs=[tensor.Tensor(split_dim),
tensor.Tensor(value)], tensor.Tensor(value)],
@ -405,18 +405,18 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteBadNumOutputsArgument(self): def testExecuteBadNumOutputsArgument(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
execute.execute( execute.execute(
'Relu', [], b'Relu', [],
inputs=[tensor.Tensor(3.0)], inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum)) attrs=('T', dtypes.float32.as_datatype_enum))
def testExecuteUnknownOp(self): def testExecuteUnknownOp(self):
with self.assertRaises(errors.NotFoundError): with self.assertRaises(errors.NotFoundError):
execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None) execute.execute(b'BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
def testExecuteUnknownAttr(self): def testExecuteUnknownAttr(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
execute.execute( execute.execute(
'Identity', b'Identity',
num_outputs=1, num_outputs=1,
inputs=[tensor.Tensor(3)], inputs=[tensor.Tensor(3)],
attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah')) attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
@ -425,7 +425,7 @@ class TFETest(test_util.TensorFlowTestCase):
def add(x, y): def add(x, y):
return execute.execute( return execute.execute(
'Add', b'Add',
num_outputs=1, num_outputs=1,
inputs=[x, y], inputs=[x, y],
attrs=('T', dtypes.int32.as_datatype_enum))[0] attrs=('T', dtypes.int32.as_datatype_enum))[0]
@ -447,7 +447,7 @@ class TFETest(test_util.TensorFlowTestCase):
y = truncated_normal(shape) y = truncated_normal(shape)
# Add would fail if x and y were not on the same device. # Add would fail if x and y were not on the same device.
execute.execute( execute.execute(
'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum)) b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))
def testInvalidDevice(self): def testInvalidDevice(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):

View File

@ -63,15 +63,14 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
device_name = ctx.device_name device_name = ctx.device_name
try: try:
outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name, outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
str(op_name), input_handles, attrs, op_name, input_handles, attrs,
num_outputs) num_outputs)
# pylint: enable=protected-access except core._NotOkStatusException as e:
except core._NotOkStatusException as e: # pylint: disable=protected-access
if name is not None: if name is not None:
message = e.message + " name: " + name message = e.message + " name: " + name
else: else:
message = e.message message = e.message
raise core._status_to_exception(e.code, message) # pylint: disable=protected-access raise core._status_to_exception(e.code, message)
# pylint: enable=protected-access # pylint: enable=protected-access
tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access

View File

@ -261,7 +261,7 @@ class _GraphModeFunction(object):
outputs[i].set_shape(s) outputs[i].set_shape(s)
else: else:
outputs = execute.execute( outputs = execute.execute(
signature.name, str(signature.name),
num_outputs=len(signature.output_arg), num_outputs=len(signature.output_arg),
inputs=all_args) inputs=all_args)
real_outputs = outputs[:len(self._returns)] real_outputs = outputs[:len(self._returns)]
@ -321,7 +321,7 @@ class _GraphModeFunction(object):
for x in tensor_inputs for x in tensor_inputs
] ]
result = execute.execute( result = execute.execute(
self._func_name, str(self._func_name),
num_outputs=self._num_outputs, num_outputs=self._num_outputs,
inputs=tensor_inputs + self._extra_inputs) inputs=tensor_inputs + self._extra_inputs)

View File

@ -650,7 +650,7 @@ void GenEagerPythonOp::AddEagerAttrs() {
void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
const string return_prefix = " _result = _execute.execute("; const string return_prefix = " _result = _execute.execute(";
const string return_args = const string return_args =
strings::StrCat("\"", op_def_.name(), "\", ", num_outputs_expr, strings::StrCat("b\"", op_def_.name(), "\", ", num_outputs_expr,
", inputs=_inputs_flat, attrs=_attrs, name=name)"); ", inputs=_inputs_flat, attrs=_attrs, name=name)");
strings::StrAppend(&result_, strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (. // Wrap the arguments, and indent to the (.

View File

@ -60,7 +60,7 @@ def _eager_reshape(tensor, shape):
attr_tshape = attr_tshape.as_datatype_enum attr_tshape = attr_tshape.as_datatype_enum
inputs_flat = [tensor, shape] inputs_flat = [tensor, shape]
attrs = ("T", attr_t, "Tshape", attr_tshape) attrs = ("T", attr_t, "Tshape", attr_tshape)
result, = execute.execute("Reshape", 1, inputs=inputs_flat, attrs=attrs) result, = execute.execute(b"Reshape", 1, inputs=inputs_flat, attrs=attrs)
return result return result
@ -70,7 +70,7 @@ def _eager_fill(dims, value):
dims = convert_to_eager_tensor(dims, dtypes.int32) dims = convert_to_eager_tensor(dims, dtypes.int32)
inputs_flat = [dims, value] inputs_flat = [dims, value]
attrs = ("T", attr_t) attrs = ("T", attr_t)
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) result, = execute.execute(b"Fill", 1, inputs=inputs_flat, attrs=attrs)
return result return result
@ -84,13 +84,6 @@ def convert_to_eager_tensor(t, dtype=None):
if dtype is not None and t.dtype != dtype: if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t return t
# Handle converting ResourceVariable to Tensor.
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
# general scheme in place.
try:
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
except AttributeError:
pass
if isinstance(t, (int, float)): if isinstance(t, (int, float)):
# Use a scalar cache. This will put each scalar of each type only once on # Use a scalar cache. This will put each scalar of each type only once on
# each device. Scalars don't use much device memory but copying scalars can # each device. Scalars don't use much device memory but copying scalars can