def replace_layer_recursive(model, old_layer, new_layer):
    """Replace the first occurrence of ``old_layer`` in ``model``'s submodule tree.

    The tree is walked depth-first through each module's ``_modules`` mapping.
    Each child is compared against ``old_layer`` before its own children are
    searched, and the walk stops at the first match. The root ``model`` itself
    is never compared.

    Args:
        model: Root module whose ``_modules`` mapping is searched.
        old_layer: Layer to look for (compared with ``==``).
        new_layer: Layer stored in place of the matched child.

    Returns:
        True if a replacement was made, False if ``old_layer`` was not found.
    """
    for name, layer in model._modules.items():
        if layer == old_layer:
            # Rebinding an existing key does not resize the dict, and we return
            # immediately, so mutating while iterating is safe here.
            model._modules[name] = new_layer
            return True
        # ``_modules`` may hold None placeholders (as torch.nn.Module allows).
        # They have no children to search, so don't recurse into them.
        if layer is not None and replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False
def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
    """Replace every submodule that is an instance of ``old_layer_type``.

    The tree is walked depth-first through each module's ``_modules`` mapping.
    The root ``model`` itself is never replaced.

    Note: the very same ``new_layer`` object is stored at every replaced
    position, so all replacements share one instance (and whatever state it
    holds).

    Args:
        model: Root module whose submodule tree is modified in place.
        old_layer_type: Class (or tuple of classes, as accepted by
            ``isinstance``) identifying the layers to replace.
        new_layer: Layer stored in place of each match.
    """
    for name, layer in model._modules.items():
        if isinstance(layer, old_layer_type):
            # Rebinding an existing key does not resize the dict.
            model._modules[name] = new_layer
            # The old layer is now detached from the tree. Don't walk (and
            # mutate) its subtree, and don't descend into new_layer either.
            continue
        # None placeholders have no children to walk.
        if layer is not None:
            replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
def find_layer_types_recursive(model, layer_types):
    """Return all submodules whose exact type is one of ``layer_types``.

    Matching tests ``type(layer)`` for membership, so subclasses of the listed
    types are NOT matched (unlike ``isinstance``).

    Args:
        model: Root module to search. Only its descendants are tested.
        layer_types: A single class, or an iterable of classes.

    Returns:
        List of matching layers, in the order produced by
        ``find_layer_predicate_recursive``.
    """
    # Normalize once. This accepts a bare class, and it materializes one-shot
    # iterables (e.g. generators) that would otherwise be exhausted by the
    # first membership test.
    if isinstance(layer_types, type):
        wanted = (layer_types,)
    else:
        wanted = tuple(layer_types)

    def predicate(layer):
        return type(layer) in wanted

    return find_layer_predicate_recursive(model, predicate)
def find_layer_predicate_recursive(model, predicate):
    """Collect every submodule of ``model`` for which ``predicate`` is true.

    The tree is walked depth-first through each module's ``_modules`` mapping.
    A matching layer is recorded before its own children are visited
    (pre-order). The root ``model`` itself is not tested. A module reachable
    along several paths is reported once per occurrence.

    Args:
        model: Root module to search.
        predicate: Callable taking a layer and returning a truthy value for
            layers to include.

    Returns:
        List of matching layers in visit order.
    """
    result = []

    def _visit(module):
        # Only the child values are needed, not their names.
        for layer in module._modules.values():
            # None placeholders are not layers and have no children.
            if layer is None:
                continue
            if predicate(layer):
                result.append(layer)
            _visit(layer)

    # A single shared accumulator avoids re-copying sublists at every depth.
    _visit(model)
    return result