perf(predict): Improve LSTM training and tuning efficiency

- Reduced hyperparameter optimization trials
- Lowered max epochs for tuning and final training
- Decreased early stopping patience
- Enabled parallel data loading with num_workers
- Refined hyperparameter search space for models
- Simplified batch size logic for CUDA
parent 5b68dd8a
......@@ -77,10 +77,11 @@ elif W == 5:
else:
print("Warning: N_TRAIN_CASES and PREDICT_CASE not set for W != 2, 3, or 5")
N_TRIALS = 25
MAX_EPOCHS_TUNE = 60
PATIENCE_TUNE = 8 # not used in CV version (we still early-stop inside fold training)
MAX_EPOCHS_FINAL = 300
N_TRIALS = 15 # Reduced from 25
MAX_EPOCHS_TUNE = 40 # Reduced from 60
PATIENCE_TUNE = 5 # Reduced from 8
MAX_EPOCHS_FINAL = 200 # Reduced from 300
N_WORKERS = 2 # for DataLoader parallelization
# -------------------------
# Utilities
......@@ -238,8 +239,8 @@ def train_one_fold( # type: ignore
train_ds = WindowDataset(x_train_scaled, y_train_scaled)
val_ds = WindowDataset(x_val_scaled, np.zeros(len(x_val_scaled), dtype=np.float32))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=pin)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin, num_workers=N_WORKERS, persistent_workers=True if N_WORKERS > 0 else False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=pin, num_workers=N_WORKERS, persistent_workers=True if N_WORKERS > 0 else False)
model = LSTMRegressor(
input_dim=n_features,
......@@ -407,7 +408,7 @@ def train_final_full_trainpool( # type: ignore
pin = (device == "cuda")
train_ds = WindowDataset(x_train_scaled, y_train_scaled)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin, num_workers=N_WORKERS, persistent_workers=True if N_WORKERS > 0 else False)
model = LSTMRegressor(
input_dim=n_features,
......@@ -486,11 +487,11 @@ def train_final_full_trainpool( # type: ignore
# Hyperparameter sampler
# -------------------------
def sample_params(device: str = "cpu"):
window_size = int(np.random.choice([40, 60, 80]))
hidden_dim = int(np.random.choice([32, 64, 128, 256], p=[0.25, 0.35, 0.30, 0.10]))
dense_dim = int(np.random.choice([16, 32, 64, 128], p=[0.20, 0.40, 0.30, 0.10]))
num_layers = int(np.random.choice([1, 2, 3], p=[0.45, 0.40, 0.15]))
dropout = float(np.random.choice([0.0, 0.1, 0.2, 0.3]))
window_size = int(np.random.choice([40, 60], p=[0.6, 0.4])) # Removed 80
hidden_dim = int(np.random.choice([32, 64, 128], p=[0.30, 0.45, 0.25])) # Removed 256
dense_dim = int(np.random.choice([16, 32, 64], p=[0.25, 0.50, 0.25])) # Removed 128
num_layers = int(np.random.choice([1, 2], p=[0.55, 0.45])) # Removed 3
dropout = float(np.random.choice([0.0, 0.1, 0.2])) # Removed 0.3
lr = float(10 ** np.random.uniform(-4.0, -2.6))
weight_decay = float(10 ** np.random.uniform(-6.0, -3.5))
......@@ -499,15 +500,10 @@ def sample_params(device: str = "cpu"):
cost = window_size * hidden_dim * num_layers
if device == "cuda":
if cost >= 80 * 256 * 2:
batch_size = 16
elif cost >= 60 * 256 * 2 or cost >= 80 * 128 * 2:
if cost >= 60 * 128 * 2:
batch_size = 32
else:
batch_size = int(np.random.choice([32, 64], p=[0.55, 0.45]))
if hidden_dim >= 256 and num_layers >= 2 and dense_dim > 64:
dense_dim = 64
batch_size = int(np.random.choice([32, 64], p=[0.45, 0.55]))
else:
batch_size = int(np.random.choice([32, 64, 128]))
......@@ -650,7 +646,7 @@ def main():
pin = (device == "cuda")
test_ds = WindowDataset(x_test_scaled, np.zeros(len(x_test_scaled), dtype=np.float32))
test_loader = DataLoader(
test_ds, batch_size=best_params["batch_size"], shuffle=False, pin_memory=pin
test_ds, batch_size=best_params["batch_size"], shuffle=False, pin_memory=pin, num_workers=N_WORKERS, persistent_workers=True if N_WORKERS > 0 else False
)
use_amp = (device == "cuda")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment