allow for providers dictionary keys to be either a ProviderConfiguration or a dict

This commit is contained in:
Sumner Evans
2021-10-18 21:46:12 -06:00
parent 8eb5f73289
commit b3679d25fa

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
import pickle import pickle
from dataclasses import dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, cast, Dict, Optional, Tuple, Type, Union from typing import Any, cast, Dict, Optional, Tuple, Type, Union
@@ -66,16 +66,28 @@ class ProviderConfiguration:
def encode_providers( def encode_providers(
providers_dict: Dict[str, Dict[str, Any]] providers_dict: Dict[str, Union[ProviderConfiguration, Dict[str, Any]]]
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, Dict[str, Any]]:
return { return {
id_: { id_: {
**(config.__dict__), **(config if isinstance(config, dict) else asdict(config)),
"ground_truth_adapter_type": config.ground_truth_adapter_type.__name__, "ground_truth_adapter_type": (
config["ground_truth_adapter_type"]
if isinstance(config, dict)
else config.ground_truth_adapter_type
).__name__,
"caching_adapter_type": ( "caching_adapter_type": (
cast(type, config.get("caching_adapter_type")).__name__ (
if config.caching_adapter_type is None cast(type, config.get("caching_adapter_type")).__name__
else None if config.get("caching_adapter_type") is not None
else None
)
if isinstance(config, dict)
else (
config.caching_adapter_type.__name__
if config.caching_adapter_type is not None
else None
)
), ),
} }
for id_, config in providers_dict.items() for id_, config in providers_dict.items()