Fix a problem with saving hysteretic-curve models, and update those models and their results

parent 1afe0825
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -462,6 +462,13 @@ def cv_score_for_params_cached(
return mean_rmse, std_rmse, arr
def strip_compile_prefix(state_dict: dict) -> dict:
    """Strip the ``_orig_mod.`` prefix that torch.compile adds to state-dict keys.

    When a model is wrapped by ``torch.compile``, its ``state_dict()`` keys are
    saved as ``"_orig_mod.<name>"``. This returns the input dict unchanged when
    no key carries the prefix; otherwise it builds a new dict with the first
    occurrence of the prefix removed from every key.
    """
    prefix = "_orig_mod."
    has_compiled_keys = any(key.startswith(prefix) for key in state_dict)
    if has_compiled_keys:
        return {key.replace(prefix, "", 1): value for key, value in state_dict.items()}
    return state_dict
# -------------------------
# Final training on 100% train_pool (cached) - no val
# Keep best by TRAIN loss
......@@ -561,7 +568,8 @@ def train_final_full_trainpool_cached( # type: ignore
epoch_loss = running / max(1, n_batches)
if epoch_loss < best_train_loss:
best_train_loss = epoch_loss
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
raw_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
best_state = strip_compile_prefix(raw_state)
del model, optimizer, criterion, train_loader, train_ds
gc.collect()
......@@ -655,7 +663,7 @@ def main():
if b_val == 29:
predict_case = 12
elif b_val == 34:
predict_case = 11
predict_case = 12
else:
print(f"Warning: No predict_case set for W={w_val}, B={b_val}")
elif w_val == 3:
......@@ -668,7 +676,7 @@ def main():
print(f"Warning: No predict_case set for W={w_val}, B={b_val}")
elif w_val == 5:
n_train_case = 64
predict_case = 13
predict_case = 70
else:
print("Warning: n_train_case and predict_case not set for W != 2, 3, or 5")
......@@ -875,6 +883,8 @@ def main():
dropout=best_params_model["dropout"]
).to(device)
final_state = strip_compile_prefix(final_state)
# compile not necessary for predict, but harmless if final was compiled; keep OFF to avoid
# overhead
final_model.load_state_dict(final_state)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment