47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
def load_csv(path: str):
    """Read a simple comma-separated file whose data cells are Python literals.

    The first non-blank line is the header; it is split on ',' into column
    names (no quoting or whitespace trimming). Every following non-blank line
    is parsed as a single Python literal expression with ``ast.literal_eval``,
    so numbers, quoted strings, booleans and ``None`` keep their Python types,
    e.g. ``0.5,3,'a'`` -> ``[0.5, 3, 'a']``.

    Returns:
        (header, rows): ``header`` is a list[str]; ``rows`` holds one list of
        field values per data line. An empty (or whitespace-only) file yields
        ``([], [])``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError / SyntaxError: if a data line is not a valid literal
            expression (including any attempt to call functions or access
            names -- those are rejected, never executed).
    """
    import ast  # local import keeps this helper self-contained

    # The with-block closes the file handle deterministically.
    # Text mode's universal-newline handling turns '\r\n' into '\n'.
    with open(path) as f:
        text = f.read().strip()
    if not text:
        return [], []

    header_line, *data_lines = text.split('\n')
    header = header_line.split(',')

    rows = []
    for line in data_lines:
        if not line.strip():
            continue  # tolerate blank lines between records
        # literal_eval only accepts literals, so a crafted data file
        # cannot execute code the way eval() would allow.
        value = ast.literal_eval(line)
        # "1,2" parses to a tuple, but a single-column "5" parses to the
        # bare value 5; normalise both to a list of fields.
        rows.append(list(value) if isinstance(value, tuple) else [value])
    return header, rows
|
|
|
|
def labeled_rows(header: list, rows: list):
    """Turn each positional row into a dict keyed by column name.

    Value ``row[i]`` is stored under ``header[i]``. A row shorter than
    ``header`` simply lacks the trailing keys; a row longer than ``header``
    raises ``IndexError``. If ``header`` repeats a name, the later column's
    value wins.
    """
    return [
        {header[col]: value for col, value in enumerate(row)}
        for row in rows
    ]
|
|
|
|
def last_row_before_t(rows: list, t: float):
    """Return the row immediately preceding the first row whose "time" >= t.

    Scanning stops at the first row with ``row["time"] >= t``. For rows
    sorted ascending by "time" this is the last row with ``row["time"] < t``.
    Unsorted input is not checked, so later rows are never examined.
    Returns None when ``rows`` is empty or its first row already has
    time >= t. If no row reaches ``t``, the final row is returned.
    """
    candidate = None
    for entry in rows:
        if entry["time"] >= t:
            return candidate
        candidate = entry
    return candidate
|
|
|
|
|
|
def extract_m(row: dict) -> list:
    """Return ``[M(state0), M(state1), ...]``: the ``M(state<n>)`` values of ``row``.

    Every key that starts with ``'M(state'`` and ends with ``')'`` is treated
    as an indexed column. Its values are returned ordered by the integer
    ``n``, regardless of the order in which the keys appear in ``row``. Other
    keys are ignored. A row with no such keys yields ``[]``.

    Raises:
        ValueError: if a matching key's index is not an integer
            (e.g. ``'M(statex)'``), or if the indices are not exactly
            0, 1, ..., k-1 (a gap or a duplicate such as ``M(state1)`` and
            ``M(state01)``).
    """
    prefix, suffix = 'M(state', ')'

    indexed = []
    for key, value in row.items():
        if key.startswith(prefix) and key.endswith(suffix):
            # int() raises ValueError for a non-numeric index.
            indexed.append((int(key[len(prefix):-len(suffix)]), value))

    # Sort on the index only: the values themselves may not be comparable.
    indexed.sort(key=lambda pair: pair[0])

    # Explicit check (not `assert`) so validation survives `python -O`.
    for expected, (n, _) in enumerate(indexed):
        if n != expected:
            raise ValueError(
                f"M(state<n>) columns must be numbered 0..{len(indexed) - 1} "
                f"without gaps or duplicates; found index {n} at position {expected}"
            )
    return [value for _, value in indexed]
|