diff --git a/sublime_music/adapters/adapter_base.py b/sublime_music/adapters/adapter_base.py index 087a33b..19fd198 100644 --- a/sublime_music/adapters/adapter_base.py +++ b/sublime_music/adapters/adapter_base.py @@ -170,7 +170,6 @@ class CacheMissError(Exception): KEYRING_APP_NAME = "app.sublimemusic.SublimeMusic" -@dataclass class ConfigurationStore(dict): """ This defines an abstract store for all configuration parameters for a given Adapter. diff --git a/sublime_music/config.py b/sublime_music/config.py index 9c1cd0b..9e981c5 100644 --- a/sublime_music/config.py +++ b/sublime_music/config.py @@ -1,7 +1,7 @@ import logging import os import pickle -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional, Tuple, Type, Union @@ -64,26 +64,33 @@ class ProviderConfiguration: if self.caching_adapter_config: self.caching_adapter_config.persist_secrets() + def asdict(self) -> Dict[str, Any]: + def get_typename(key: str) -> Optional[str]: + key += "_type" + if isinstance(self, dict): + return type_.__name__ if (type_ := self.get(key)) else None + else: + return type_.__name__ if (type_ := getattr(self, key)) else None + + return { + "id": self.id, + "name": self.name, + "ground_truth_adapter_type": get_typename("ground_truth_adapter"), + "ground_truth_adapter_config": dict(self.ground_truth_adapter_config), + "caching_adapter_type": get_typename("caching_adapter"), + "caching_adapter_config": dict(self.caching_adapter_config or {}), + } + def encode_providers( providers_dict: Dict[str, Union[ProviderConfiguration, Dict[str, Any]]] ) -> Dict[str, Dict[str, Any]]: - def get_typename( - config: Union[ProviderConfiguration, Dict[str, Any]], - key: str, - ) -> Optional[str]: - key += "_type" - if isinstance(config, dict): - return type_.__name__ if (type_ := config.get(key)) else None - else: - return type_.__name__ if (type_ := getattr(config, key)) else None - return { - id_: { - **(config if isinstance(config, dict) else asdict(config)), - "ground_truth_adapter_type": get_typename(config, "ground_truth_adapter"), - "caching_adapter_type": get_typename(config, "caching_adapter"), - } + id_: ( + config + if isinstance(config, ProviderConfiguration) + else ProviderConfiguration(**config) + ).asdict() for id_, config in providers_dict.items() } diff --git a/tests/config_test.py b/tests/config_test.py index d508637..e03bd7c 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -68,6 +68,11 @@ def test_json_load_unload(config_filename: Path, tmp_path: Path): assert original_config.version == loaded_config.version assert original_config.providers == loaded_config.providers assert original_config.provider == loaded_config.provider + assert original_config.provider and loaded_config.provider + assert ( + original_config.provider.ground_truth_adapter_config.items() + == loaded_config.provider.ground_truth_adapter_config.items() + ) def test_config_migrate_v5_to_v6(config_filename: Path, cwd: Path):