Expose broadcast_weights
as a tf.__internal__ API.
PiperOrigin-RevId: 340952998 Change-Id: If3da39180bc5aa726a93df6cb286eba33c760a4c
This commit is contained in:
parent
50950be11b
commit
a26cc18a8b
@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sets
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def _has_valid_dims(weights_shape, values_shape):
|
||||
@ -133,6 +134,7 @@ def assert_broadcastable(weights, values):
|
||||
return control_flow_ops.Assert(is_valid_shape, data, name=scope)
|
||||
|
||||
|
||||
@tf_export("__internal__.ops.broadcast_weights", v1=[])
|
||||
def broadcast_weights(weights, values):
|
||||
"""Broadcast `weights` to the same shape as `values`.
|
||||
|
||||
|
@ -11,6 +11,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"__internal__/distribute/combinations/__init__.py",
|
||||
"__internal__/distribute/multi_process_runner/__init__.py",
|
||||
"__internal__/nest/__init__.py",
|
||||
"__internal__/ops/__init__.py",
|
||||
"__internal__/test/__init__.py",
|
||||
"__internal__/test/combinations/__init__.py",
|
||||
"__internal__/tf2/__init__.py",
|
||||
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.__internal__.ops"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "broadcast_weights"
|
||||
argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -20,6 +20,10 @@ tf_module {
|
||||
name: "nest"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "ops"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "test"
|
||||
mtype: "<type \'module\'>"
|
||||
|
Loading…
x
Reference in New Issue
Block a user