fix encode/decode of provider configurations again

This commit is contained in:
Sumner Evans
2021-11-11 00:16:31 -07:00
parent 120eeb04a9
commit 91d753bd46
3 changed files with 28 additions and 17 deletions

View File

@@ -170,7 +170,6 @@ class CacheMissError(Exception):
KEYRING_APP_NAME = "app.sublimemusic.SublimeMusic" KEYRING_APP_NAME = "app.sublimemusic.SublimeMusic"
@dataclass
class ConfigurationStore(dict): class ConfigurationStore(dict):
""" """
This defines an abstract store for all configuration parameters for a given Adapter. This defines an abstract store for all configuration parameters for a given Adapter.

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
import pickle import pickle
from dataclasses import asdict, dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, Union from typing import Any, Dict, Optional, Tuple, Type, Union
@@ -64,26 +64,33 @@ class ProviderConfiguration:
if self.caching_adapter_config: if self.caching_adapter_config:
self.caching_adapter_config.persist_secrets() self.caching_adapter_config.persist_secrets()
def asdict(self) -> Dict[str, Any]:
def get_typename(key: str) -> Optional[str]:
key += "_type"
if isinstance(self, dict):
return type_.__name__ if (type_ := self.get(key)) else None
else:
return type_.__name__ if (type_ := getattr(self, key)) else None
return {
"id": self.id,
"name": self.name,
"ground_truth_adapter_type": get_typename("ground_truth_adapter"),
"ground_truth_adapter_config": dict(self.ground_truth_adapter_config),
"caching_adapter_type": get_typename("caching_adapter"),
"caching_adapter_config": dict(self.caching_adapter_config or {}),
}
def encode_providers( def encode_providers(
providers_dict: Dict[str, Union[ProviderConfiguration, Dict[str, Any]]] providers_dict: Dict[str, Union[ProviderConfiguration, Dict[str, Any]]]
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, Dict[str, Any]]:
def get_typename(
config: Union[ProviderConfiguration, Dict[str, Any]],
key: str,
) -> Optional[str]:
key += "_type"
if isinstance(config, dict):
return type_.__name__ if (type_ := config.get(key)) else None
else:
return type_.__name__ if (type_ := getattr(config, key)) else None
return { return {
id_: { id_: (
**(config if isinstance(config, dict) else asdict(config)), config
"ground_truth_adapter_type": get_typename(config, "ground_truth_adapter"), if isinstance(config, ProviderConfiguration)
"caching_adapter_type": get_typename(config, "caching_adapter"), else ProviderConfiguration(**config)
} ).asdict()
for id_, config in providers_dict.items() for id_, config in providers_dict.items()
} }

View File

@@ -68,6 +68,11 @@ def test_json_load_unload(config_filename: Path, tmp_path: Path):
assert original_config.version == loaded_config.version assert original_config.version == loaded_config.version
assert original_config.providers == loaded_config.providers assert original_config.providers == loaded_config.providers
assert original_config.provider == loaded_config.provider assert original_config.provider == loaded_config.provider
assert original_config.provider and loaded_config.provider
assert (
original_config.provider.ground_truth_adapter_config.items()
== loaded_config.provider.ground_truth_adapter_config.items()
)
def test_config_migrate_v5_to_v6(config_filename: Path, cwd: Path): def test_config_migrate_v5_to_v6(config_filename: Path, cwd: Path):