irtk.parameter.ParamGroup
- class irtk.parameter.ParamGroup
A class to manage a group of parameters.
This class provides functionality to add, access, and update parameters, as well as track which parameters require gradients or have been updated.
- __init__()
Initialize an empty ParamGroup.
- add_param(name: str, value: Any, is_tensor: bool = False, is_diff: bool = False, help_msg: str = '') → None
Add a new parameter to the group.
- Parameters:
name – The name of the parameter.
value – The value of the parameter.
is_tensor – Whether the parameter is a tensor.
is_diff – Whether the parameter is differentiable.
help_msg – A description of the parameter.
- get_requiring_grad() → List[str]
Get names of parameters that require gradients.
- Returns:
A list of the names of parameters that are differentiable and require gradients.
- get_updated() → List[str]
Get names of parameters that have been updated.
- Returns:
A list of the names of parameters that have been marked as updated.
- mark_updated(param_name: str, updated: bool = True) → None
Mark a parameter as updated or not.
- Parameters:
param_name – The name of the parameter to mark.
updated – Whether to mark the parameter as updated (True) or not (False).
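Example

The methods above can be illustrated with a small self-contained sketch. `ParamGroupSketch` is a hypothetical stand-in that only mirrors the documented signatures; it is not irtk's implementation. The per-parameter record layout, the `False` defaults for the two flags, and the `set_requires_grad` helper are all assumptions made for illustration. This page does not show how the real class sets the "requires gradient" flag. The parameter names `albedo` and `filename` are likewise made up.

```python
from typing import Any, Dict, List


class ParamGroupSketch:
    """Hypothetical stand-in mirroring the documented ParamGroup interface.

    NOT irtk's implementation: the storage layout and the
    set_requires_grad helper are assumptions for illustration only.
    """

    def __init__(self) -> None:
        # One record per parameter, keyed by parameter name.
        self._params: Dict[str, Dict[str, Any]] = {}

    def add_param(self, name: str, value: Any, is_tensor: bool = False,
                  is_diff: bool = False, help_msg: str = '') -> None:
        self._params[name] = {
            'value': value,
            'is_tensor': is_tensor,
            'is_diff': is_diff,
            'help_msg': help_msg,
            'requires_grad': False,  # assumed default
            'updated': False,        # assumed default
        }

    def set_requires_grad(self, name: str, flag: bool = True) -> None:
        # Hypothetical helper, not part of the documented API.
        self._params[name]['requires_grad'] = flag

    def get_requiring_grad(self) -> List[str]:
        # Keep only parameters that are differentiable AND flagged.
        return [n for n, p in self._params.items()
                if p['is_diff'] and p['requires_grad']]

    def get_updated(self) -> List[str]:
        return [n for n, p in self._params.items() if p['updated']]

    def mark_updated(self, param_name: str, updated: bool = True) -> None:
        self._params[param_name]['updated'] = updated


group = ParamGroupSketch()
group.add_param('albedo', [0.8, 0.2, 0.2], is_tensor=True, is_diff=True,
                help_msg='Diffuse reflectance (RGB)')
group.add_param('filename', 'mesh.obj', help_msg='Path to the mesh file')

group.set_requires_grad('albedo')       # hypothetical helper (see above)
group.mark_updated('albedo')            # flag as updated
requiring_grad = group.get_requiring_grad()
updated_before_reset = group.get_updated()

group.mark_updated('albedo', updated=False)  # clear the flag again
updated_after_reset = group.get_updated()
```

In this sketch, a non-differentiable parameter such as `filename` never appears in the result of `get_requiring_grad()`, and calling `mark_updated(..., updated=False)` removes a name from the result of `get_updated()`. Whether the real class behaves identically in edge cases (for example, marking an unknown name) is not stated on this page.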