allow for providers dictionary keys to be either a ProviderConfiguration or a dict
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, cast, Dict, Optional, Tuple, Type, Union
|
from typing import Any, cast, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
@@ -66,16 +66,28 @@ class ProviderConfiguration:
|
|||||||
|
|
||||||
|
|
||||||
def encode_providers(
|
def encode_providers(
|
||||||
providers_dict: Dict[str, Dict[str, Any]]
|
providers_dict: Dict[str, Union[ProviderConfiguration, Dict[str, Any]]]
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
id_: {
|
id_: {
|
||||||
**(config.__dict__),
|
**(config if isinstance(config, dict) else asdict(config)),
|
||||||
"ground_truth_adapter_type": config.ground_truth_adapter_type.__name__,
|
"ground_truth_adapter_type": (
|
||||||
|
config["ground_truth_adapter_type"]
|
||||||
|
if isinstance(config, dict)
|
||||||
|
else config.ground_truth_adapter_type
|
||||||
|
).__name__,
|
||||||
"caching_adapter_type": (
|
"caching_adapter_type": (
|
||||||
cast(type, config.get("caching_adapter_type")).__name__
|
(
|
||||||
if config.caching_adapter_type is None
|
cast(type, config.get("caching_adapter_type")).__name__
|
||||||
else None
|
if config.get("caching_adapter_type") is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if isinstance(config, dict)
|
||||||
|
else (
|
||||||
|
config.caching_adapter_type.__name__
|
||||||
|
if config.caching_adapter_type is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for id_, config in providers_dict.items()
|
for id_, config in providers_dict.items()
|
||||||
|
Reference in New Issue
Block a user