Internal changes only.
PiperOrigin-RevId: 332484552 Change-Id: Ie9b9b1529706d284fe6d3485ca5eeea04335d04f
This commit is contained in:
parent
2ed4c2279c
commit
aafebf2488
@ -114,8 +114,27 @@ class TestCombination(object):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
|
||||||
class ParameterModifier(object):
|
class ParameterModifier(object):
|
||||||
"""Customizes the behavior of a particular parameter."""
|
"""Customizes the behavior of a particular parameter.
|
||||||
|
|
||||||
|
Users should override `modified_arguments()` to modify the parameter they
|
||||||
|
want, eg: change the value of certain parameter or filter it from the params
|
||||||
|
passed to the test case.
|
||||||
|
|
||||||
|
See the sample usage below, it will change any negative parameters to zero
|
||||||
|
before it gets passed to test case.
|
||||||
|
```
|
||||||
|
class NonNegativeParameterModifier(ParameterModifier):
|
||||||
|
|
||||||
|
def modified_arguments(self, kwargs, requested_parameters):
|
||||||
|
updates = {}
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
if value < 0:
|
||||||
|
updates[name] = 0
|
||||||
|
return updates
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
DO_NOT_PASS_TO_THE_TEST = object()
|
DO_NOT_PASS_TO_THE_TEST = object()
|
||||||
|
|
||||||
@ -171,8 +190,39 @@ class ParameterModifier(object):
|
|||||||
return id(self.__class__)
|
return id(self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
|
||||||
class OptionalParameter(ParameterModifier):
|
class OptionalParameter(ParameterModifier):
|
||||||
"""A parameter that is optional in `combine()` and in the test signature."""
|
"""A parameter that is optional in `combine()` and in the test signature.
|
||||||
|
|
||||||
|
`OptionalParameter` is usually used with `TestCombination` in the
|
||||||
|
`parameter_modifiers()`. It allows `TestCombination` to skip certain
|
||||||
|
parameters when passing them to `combine()`, since the `TestCombination` might
|
||||||
|
consume the param and create some context based on the value it gets.
|
||||||
|
|
||||||
|
See the sample usage below:
|
||||||
|
|
||||||
|
```
|
||||||
|
class EagerGraphCombination(TestCombination):
|
||||||
|
|
||||||
|
def context_managers(self, kwargs):
|
||||||
|
mode = kwargs.pop("mode", None)
|
||||||
|
if mode is None:
|
||||||
|
return []
|
||||||
|
elif mode == "eager":
|
||||||
|
return [context.eager_mode()]
|
||||||
|
elif mode == "graph":
|
||||||
|
return [ops.Graph().as_default(), context.graph_mode()]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"'mode' has to be either 'eager' or 'graph', got {}".format(mode))
|
||||||
|
|
||||||
|
def parameter_modifiers(self):
|
||||||
|
return [test_combinations.OptionalParameter("mode")]
|
||||||
|
```
|
||||||
|
|
||||||
|
When the test case is generated, the param "mode" will not be passed to the
|
||||||
|
test method, since it is consumed by the `EagerGraphCombination`.
|
||||||
|
"""
|
||||||
|
|
||||||
def modified_arguments(self, kwargs, requested_parameters):
|
def modified_arguments(self, kwargs, requested_parameters):
|
||||||
if self._parameter_name in requested_parameters:
|
if self._parameter_name in requested_parameters:
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
path: "tensorflow.__internal__.test.combinations.OptionalParameter"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.framework.test_combinations.OptionalParameter\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "DO_NOT_PASS_TO_THE_TEST"
|
||||||
|
mtype: "<type \'object\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "modified_arguments"
|
||||||
|
argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
path: "tensorflow.__internal__.test.combinations.ParameterModifier"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "DO_NOT_PASS_TO_THE_TEST"
|
||||||
|
mtype: "<type \'object\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "modified_arguments"
|
||||||
|
argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,14 @@ tf_module {
|
|||||||
name: "NamedObject"
|
name: "NamedObject"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "OptionalParameter"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "ParameterModifier"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "TestCombination"
|
name: "TestCombination"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
Loading…
Reference in New Issue
Block a user