gpt-4chan-public/src/compute_metrics.py

#!/usr/bin/env python3
import json
from pathlib import Path
from loguru import logger
from tabulate import tabulate
import lm_eval.tasks
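
# Compare lm-evaluation-harness results between GPT-J-6B and GPT-4chan by parsing
# the JSON summaries embedded in the captured eval stdout logs under ./eval_logs,
# then print a side-by-side table and save it to ./results.table.txt.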
m1 = 'GPT-J-6B'
m2 = 'GPT-4chan'
log_dir = Path('./eval_logs')
all_tasks = set()
model_data = {}
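
# Pass 1: parse every eval log and collect per-task metric values (and stderrs) per model.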
for fn in log_dir.rglob('log_*.stdout.txt'):
    try:
        file_text = fn.read_text()
        # The results JSON is embedded in the captured stdout; take everything
        # between the first '{' and the last '}' and parse it.
        data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
        model = data['config']['model_args'].split('=')[1]
        # Runs whose model_args mention 'fp16' are treated as the GPT-4chan run,
        # everything else as the GPT-J-6B baseline.
        model = m2 if 'fp16' in model else m1
        if model not in model_data:
            model_data[model] = {}
        results = data['results']
        tasks = list(results.keys())
        assert len(tasks) == 1, 'Only one task supported'
        task = tasks[0]
        if task in model_data[model]:
            raise ValueError(f'Duplicate task {task}')
        task_version = data['versions'][task]
        results = results[task]
        results_data = {}
        # Pair each metric with its reported standard error, if one is present.
        for result_key in results:
            if result_key.endswith('_stderr'):
                continue
            result_value = results[result_key]
            results_data[result_key] = {'value': result_value}
            stderr_key = f'{result_key}_stderr'
            if stderr_key in results:
                results_data[result_key]['stderr'] = results[stderr_key]
            else:
                logger.warning(f'No stderr for {result_key} in {results}')
        model_data[model][task] = {'version': task_version, 'results': results_data}
        all_tasks.add(task)
    except Exception:
        logger.exception(f'Failed to parse {fn}')
        continue
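
# Pass 2: line up the two models' results per task and flag roughly significant differences.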
all_models = list(sorted(model_data.keys()))
table_data = []
for task in all_tasks:
    try:
        # get_task returns the task class; higher_is_better() typically ignores self,
        # so calling it unbound with None yields the metric -> bool mapping without
        # instantiating the task.
        higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
    except Exception:
        logger.warning(f'Failed to get higher_is_better for {task}')
        continue
    if any(task not in model_data[model] for model in all_models):
        logger.warning(f'No results for {task}')
        continue
    results = model_data[m1][task]['results']
    results2 = model_data[m2][task]['results']
    for metric in results:
        result_value = results[metric]['value']
        stderr_value = results[metric].get('stderr', 0.0)
        result2_value = results2[metric]['value']
        stderr2_value = results2[metric].get('stderr', 0.0)
        # Rough significance heuristic: difference divided by the mean stderr,
        # signed so that a positive value favours GPT-4chan (m2).
        significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
        if higher_is_better[metric]:
            significance *= -1
        # Mark differences larger than one (averaged) standard error.
        if abs(significance) > 1:
            significant = '+' if significance > 0 else '-'
        else:
            significant = ''
        table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])
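
# Render the comparison as a pipe-delimited (Markdown-style) table, print it, and save it.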
table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
print(table_str)
Path('./results.table.txt').write_text(table_str)