diff --git a/crates/post/scripts/stacked_cores.py b/crates/post/scripts/stacked_cores.py new file mode 100644 index 0000000..6320194 --- /dev/null +++ b/crates/post/scripts/stacked_cores.py @@ -0,0 +1,46 @@ +def load_csv(path: str): + """ + returns (header: list[str], rows: list[list[T]]) + """ + header = [] + rows = [] + for i, line in enumerate(open(path).read().strip().split('\n')): + if i == 0: + header = line.split(',') + else: + rows.append(eval(line)) + return header, rows + +def labeled_rows(header: list, rows: list): + """ + return a list of dicts, + transforming each row into a kv map + """ + new_rows = [] + for row in rows: + new_rows.append({ header[i]: elem for i, elem in enumerate(row) }) + return new_rows + +def last_row_before_t(rows: list, t: float): + """ + return the last row for which row[time] < t + """ + prev_row = None + for row in rows: + if row["time"] >= t: + break + prev_row = row + return prev_row + + +def extract_m(row: dict) -> list: + """ + return [M(state0), M(state1), ...] + """ + m = [] + for k, v in row.items(): + if k.startswith('M(state') and k.endswith(')'): + n = int(k[len('M(state'):-1]) + assert n == len(m) + m.append(v) + return m diff --git a/crates/post/scripts/stacked_cores_8xx.py b/crates/post/scripts/stacked_cores_8xx.py index f8ca6ef..2824a03 100755 --- a/crates/post/scripts/stacked_cores_8xx.py +++ b/crates/post/scripts/stacked_cores_8xx.py @@ -5,52 +5,7 @@ to extract higher-level info from them. """ import sys -def load_csv(path: str): - """ - returns (header: list[str], rows: list[list[T]]) - """ - header = [] - rows = [] - for i, line in enumerate(open(path).read().strip().split('\n')): - if i == 0: - header = line.split(',') - else: - rows.append(eval(line)) - return header, rows - -def labeled_rows(header: list, rows: list): - """ - return a list of dicts, - transforming each row into a kv map - """ - new_rows = [] - for row in rows: - new_rows.append({ header[i]: elem for i, elem in enumerate(row) }) - return new_rows - -def last_row_before_t(rows: list, t: float): - """ - return the last row for which row[time] < t - """ - prev_row = None - for row in rows: - if row["time"] >= t: - break - prev_row = row - return prev_row - - -def extract_m(row: dict) -> list: - """ - return [M(state0), M(state1), ...] - """ - m = [] - for k, v in row.items(): - if k.startswith('M(state') and k.endswith(')'): - n = int(k[len('M(state'):-1]) - assert n == len(m) - m.append(v) - return m +from stacked_cores import load_csv, labeled_rows, last_row_before_t, extract_m def extract_8xx(path: str): header, raw_rows = load_csv(path)