Source code for causalcompass.algorithms.clstm

import torch
import numpy as np
from .base import BaseCausalAlgorithm
from .ngc.clstm import cLSTM as _cLSTM
from .ngc.clstm import train_model_ista

class CLSTM(BaseCausalAlgorithm):
    """
    Deep learning-based causal discovery method that uses component-wise
    LSTMs to model nonlinear Granger causality.

    References
    ----------
    https://github.com/iancovert/Neural-GC

    Parameters
    ----------
    context : int, default 10
        Context window length
    hidden_dim : int, default 100
        Number of units in LSTM cell
    lam : float, default 0.005
        Sparsity penalty term parameter
    lr : float, default 0.01
        Learning rate
    max_iter : int, default 20000
        Maximum training iterations
    lam_ridge : float, default 1e-2
        Ridge regularization parameter
    device : str or torch.device, default 'cuda'
        Computation device. If a CUDA device is requested but CUDA is not
        available, ``run`` falls back to the CPU.
    seed : int or None, default None
        Forwarded unchanged to ``BaseCausalAlgorithm.__init__``.

    Examples
    --------
    >>> from causalcompass.algorithms import CLSTM
    >>> model = CLSTM(context=10, hidden_dim=100, lam=0.005, lr=0.01, max_iter=20000, device='cuda')
    >>> predicted_adj = model.run(X)
    >>> all_metrics, no_diag_metrics = model.eval(true_adj, predicted_adj)
    """

    def __init__(self, context=10, hidden_dim=100, lam=0.005, lr=0.01,
                 max_iter=20000, lam_ridge=1e-2, device='cuda', seed=None):
        """
        Initialize cLSTM
        """
        super().__init__(seed=seed)
        # run() returns unthresholded scores (GC(threshold=False)), so the
        # output is flagged as continuous rather than binary.
        self._eval_output_type = "continuous"
        self.context = context
        self.hidden_dim = hidden_dim
        self.lam = lam
        self.lr = lr
        self.max_iter = max_iter
        self.lam_ridge = lam_ridge
        self.device = device

    def run(self, X):
        """
        Run cLSTM algorithm.

        :param X: Time series data, shape (T, p). Any array-like accepted by
            ``numpy.asarray`` is allowed.
        :return: Predicted adjacency matrix, shape (p, p).
        :raises ValueError: If ``X`` is not two-dimensional.
        """
        # Coerce up front so lists / DataFrames work with the np.newaxis
        # indexing below, and fail early with a readable message otherwise.
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(
                f"X must be a 2-D array of shape (T, p); got shape {X.shape}"
            )

        # Device setup. str() lets callers pass either a device string or a
        # torch.device object (which has no .startswith()).
        requested = str(self.device)
        if requested.startswith("cuda") and not torch.cuda.is_available():
            print("CUDA is not available, switching to CPU.")
            device = torch.device("cpu")
        else:
            device = torch.device(requested)

        # Prepare data: Add batch dimension (1, T, p)
        X_tensor = torch.tensor(X[np.newaxis], dtype=torch.float32, device=device)

        # Initialize model: one series per column of X
        model = _cLSTM(
            num_series=X.shape[-1],
            hidden=self.hidden_dim,
        ).to(device)

        # Train model
        train_model_ista(
            model,
            X_tensor,
            context=self.context,
            lam=self.lam,
            lam_ridge=self.lam_ridge,
            lr=self.lr,
            max_iter=self.max_iter,
            verbose=0,
        )

        # Extract graph structure (unthresholded) and move it to host memory
        predicted_adj = model.GC(threshold=False).detach().cpu().numpy()
        return predicted_adj