From aafebf2488854e4607147f4f3d442f7329fb06eb Mon Sep 17 00:00:00 2001 From: Scott Zhu <scottzhu@google.com> Date: Fri, 18 Sep 2020 10:59:08 -0700 Subject: [PATCH] Internal changes only. PiperOrigin-RevId: 332484552 Change-Id: Ie9b9b1529706d284fe6d3485ca5eeea04335d04f --- .../python/framework/test_combinations.py | 54 ++++++++++++++++++- ...est.combinations.-optional-parameter.pbtxt | 18 +++++++ ...est.combinations.-parameter-modifier.pbtxt | 17 ++++++ ...rflow.__internal__.test.combinations.pbtxt | 8 +++ 4 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt diff --git a/tensorflow/python/framework/test_combinations.py b/tensorflow/python/framework/test_combinations.py index 09920f68adf..09b6ba478db 100644 --- a/tensorflow/python/framework/test_combinations.py +++ b/tensorflow/python/framework/test_combinations.py @@ -114,8 +114,27 @@ class TestCombination(object): return [] +@tf_export("__internal__.test.combinations.ParameterModifier", v1=[]) class ParameterModifier(object): - """Customizes the behavior of a particular parameter.""" + """Customizes the behavior of a particular parameter. + + Users should override `modified_arguments()` to modify the parameter they + want, eg: change the value of certain parameter or filter it from the params + passed to the test case. + + See the sample usage below, it will change any negative parameters to zero + before it gets passed to test case. + ``` + class NonNegativeParameterModifier(ParameterModifier): + + def modified_arguments(self, kwargs, requested_parameters): + updates = {} + for name, value in kwargs.items(): + if value < 0: + updates[name] = 0 + return updates + ``` + """ DO_NOT_PASS_TO_THE_TEST = object() @@ -171,8 +190,39 @@ class ParameterModifier(object): return id(self.__class__) +@tf_export("__internal__.test.combinations.OptionalParameter", v1=[]) class OptionalParameter(ParameterModifier): - """A parameter that is optional in `combine()` and in the test signature.""" + """A parameter that is optional in `combine()` and in the test signature. + + `OptionalParameter` is usually used with `TestCombination` in the + `parameter_modifiers()`. It allows `TestCombination` to skip certain + parameters when passing them to `combine()`, since the `TestCombination` might + consume the param and create some context based on the value it gets. + + See the sample usage below: + + ``` + class EagerGraphCombination(TestCombination): + + def context_managers(self, kwargs): + mode = kwargs.pop("mode", None) + if mode is None: + return [] + elif mode == "eager": + return [context.eager_mode()] + elif mode == "graph": + return [ops.Graph().as_default(), context.graph_mode()] + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph', got {}".format(mode)) + + def parameter_modifiers(self): + return [test_combinations.OptionalParameter("mode")] + ``` + + When the test case is generated, the param "mode" will not be passed to the + test method, since it is consumed by the `EagerGraphCombination`. + """ def modified_arguments(self, kwargs, requested_parameters): if self._parameter_name in requested_parameters: diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt new file mode 100644 index 00000000000..4dedda05bf6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.__internal__.test.combinations.OptionalParameter" +tf_class { + is_instance: "<class \'tensorflow.python.framework.test_combinations.OptionalParameter\'>" + is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>" + is_instance: "<type \'object\'>" + member { + name: "DO_NOT_PASS_TO_THE_TEST" + mtype: "<type \'object\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "modified_arguments" + argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt new file mode 100644 index 00000000000..9b2438ccc8a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.__internal__.test.combinations.ParameterModifier" +tf_class { + is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>" + is_instance: "<type \'object\'>" + member { + name: "DO_NOT_PASS_TO_THE_TEST" + mtype: "<type \'object\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "modified_arguments" + argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt index 08695f72bea..b5190c37802 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt @@ -4,6 +4,14 @@ tf_module { name: "NamedObject" mtype: "<type \'type\'>" } + member { + name: "OptionalParameter" + mtype: "<type \'type\'>" + } + member { + name: "ParameterModifier" + mtype: "<type \'type\'>" + } member { name: "TestCombination" mtype: "<type \'type\'>"