Convert a lightning checkpoint to our own format.
The final checkpoint will contain the model configuration and the state dict.
It will be saved to:
out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
Parameters:
- lightning_checkpoint (pathlib.Path) – Path to the lightning checkpoint.
- out_directory (pathlib.Path) – Output directory for the converted checkpoint.
- checkpoint_name (str) – A unique name of the new checkpoint.
- framework (str, default: 'smp') – The framework used for the model. Defaults to "smp".
Source code in darts/src/darts/legacy_training/util.py
def convert_lightning_checkpoint(
    *,
    lightning_checkpoint: Path,
    out_directory: Path,
    checkpoint_name: str,
    framework: str = "smp",
) -> None:
    """Convert a lightning checkpoint to our own format.

    The final checkpoint will contain the model configuration and the state dict.
    It will be saved to:

    ```python
    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    ```

    where ``formatted_date`` is the current local date formatted as ``YYYY-MM-DD``.

    The stored configuration is the ``hyper_parameters["config"]`` entry of the
    lightning checkpoint with ``model.encoder_weights`` removed (if present) and
    the keys ``time``, ``name`` and ``model_framework`` added. The stored state
    dict is the lightning ``state_dict`` with the leading ``"model."`` prefix
    stripped from its keys.

    Args:
        lightning_checkpoint (Path): Path to the lightning checkpoint.
        out_directory (Path): Output directory for the converted checkpoint.
            Created (including parents) if it does not exist.
        checkpoint_name (str): A unique name of the new checkpoint.
        framework (str, optional): The framework used for the model. Defaults to "smp".

    Raises:
        KeyError: If the checkpoint lacks ``hyper_parameters["config"]["model"]``
            or ``state_dict``.

    """
    import torch

    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which is
    # needed to read the stored hyper parameters. Only convert checkpoints that
    # come from a trusted source.
    lckpt = torch.load(lightning_checkpoint, weights_only=False, map_location=torch.device("cpu"))

    formatted_date = datetime.now().strftime("%Y-%m-%d")

    config = lckpt["hyper_parameters"]["config"]
    # Not every model configuration carries an "encoder_weights" entry, so drop
    # it only when it exists instead of failing with a KeyError.
    config["model"].pop("encoder_weights", None)
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    statedict = lckpt["state_dict"]
    # Statedict has model. prefix before every weight. We need to remove them.
    # This is an in-place function, so `statedict` itself is modified.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")

    own_ckpt = {
        "config": config,
        "statedict": statedict,
    }

    out_directory.mkdir(exist_ok=True, parents=True)
    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    torch.save(own_ckpt, out_checkpoint)

    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")
|