Internal changes only.

PiperOrigin-RevId: 332484552
Change-Id: Ie9b9b1529706d284fe6d3485ca5eeea04335d04f
This commit is contained in:
Scott Zhu 2020-09-18 10:59:08 -07:00 committed by TensorFlower Gardener
parent 2ed4c2279c
commit aafebf2488
4 changed files with 95 additions and 2 deletions

View File

@ -114,8 +114,27 @@ class TestCombination(object):
return []
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier(object):
"""Customizes the behavior of a particular parameter."""
"""Customizes the behavior of a particular parameter.
Users should override `modified_arguments()` to modify the parameter they
want, eg: change the value of certain parameter or filter it from the params
passed to the test case.
See the sample usage below, it will change any negative parameters to zero
before it gets passed to test case.
```
class NonNegativeParameterModifier(ParameterModifier):
def modified_arguments(self, kwargs, requested_parameters):
updates = {}
for name, value in kwargs.items():
if value < 0:
updates[name] = 0
return updates
```
"""
DO_NOT_PASS_TO_THE_TEST = object()
@ -171,8 +190,39 @@ class ParameterModifier(object):
return id(self.__class__)
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
"""A parameter that is optional in `combine()` and in the test signature."""
"""A parameter that is optional in `combine()` and in the test signature.
`OptionalParameter` is usually used with `TestCombination` in the
`parameter_modifiers()`. It allows `TestCombination` to skip certain
parameters when passing them to `combine()`, since the `TestCombination` might
consume the param and create some context based on the value it gets.
See the sample usage below:
```
class EagerGraphCombination(TestCombination):
def context_managers(self, kwargs):
mode = kwargs.pop("mode", None)
if mode is None:
return []
elif mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]
else:
raise ValueError(
"'mode' has to be either 'eager' or 'graph', got {}".format(mode))
def parameter_modifiers(self):
return [test_combinations.OptionalParameter("mode")]
```
When the test case is generated, the param "mode" will not be passed to the
test method, since it is consumed by the `EagerGraphCombination`.
"""
def modified_arguments(self, kwargs, requested_parameters):
if self._parameter_name in requested_parameters:

View File

@ -0,0 +1,18 @@
path: "tensorflow.__internal__.test.combinations.OptionalParameter"
tf_class {
is_instance: "<class \'tensorflow.python.framework.test_combinations.OptionalParameter\'>"
is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
is_instance: "<type \'object\'>"
member {
name: "DO_NOT_PASS_TO_THE_TEST"
mtype: "<type \'object\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "modified_arguments"
argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,17 @@
path: "tensorflow.__internal__.test.combinations.ParameterModifier"
tf_class {
is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
is_instance: "<type \'object\'>"
member {
name: "DO_NOT_PASS_TO_THE_TEST"
mtype: "<type \'object\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "modified_arguments"
argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,6 +4,14 @@ tf_module {
name: "NamedObject"
mtype: "<type \'type\'>"
}
member {
name: "OptionalParameter"
mtype: "<type \'type\'>"
}
member {
name: "ParameterModifier"
mtype: "<type \'type\'>"
}
member {
name: "TestCombination"
mtype: "<type \'type\'>"