Merge pull request #35973 from fsx950223:script-fix

PiperOrigin-RevId: 290946023
Change-Id: Ibfd06e3042665586f11cba7fb436aed74285471e
This commit is contained in:
TensorFlower Gardener 2020-01-22 07:10:12 -08:00
commit 05bc4f07f4
3 changed files with 57 additions and 1 deletions

View File

@ -4260,6 +4260,20 @@ py_test(
],
)
py_test(
name = "script_ops_test",
srcs = ["ops/script_ops_test.py"],
python_version = "PY3",
deps = [
":client_testlib",
":dtypes",
":framework_ops",
":framework_test_lib",
":script_ops",
"//third_party/py/numpy",
],
)
py_library(
name = "standard_ops",
srcs = ["ops/standard_ops.py"],

View File

@ -509,7 +509,7 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
if context.executing_eagerly():
result = func(*[x.numpy() for x in inp])
result = func(*[np.array(x) for x in inp])
result = nest.flatten(result)
result = [x if x is None else ops.convert_to_tensor(x) for x in result]

View File

@ -0,0 +1,42 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for script operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class NumpyFunctionTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_numpy_arguments(self):
def plus(a, b):
return a + b
actual_result = script_ops.numpy_function(plus, [1, 2], dtypes.int32)
expect_result = constant_op.constant(3, dtypes.int32)
self.assertAllEqual(actual_result, expect_result)
if __name__ == "__main__":
test.main()