Source code for causalcompass.algorithms.cuts.wrapper

import torch
import numpy as np
import tempfile
import shutil
from types import SimpleNamespace
from ..base import BaseCausalAlgorithm

class CUTS(BaseCausalAlgorithm):
    """
    CUTS is a deep learning-based causal discovery method tailored for
    irregular time series data.

    References
    ----------
    https://github.com/jarrycyx/UNN

    Parameters
    ----------
    input_step : int, default 10
        Number of past time steps used as input
    batch_size : int, default 32
        Training batch size
    weight_decay : float, default 0.001
        Controls the strength of regularization
    device : str, default 'cuda'
        Computation device
    seed : optional, default None
        Passed through to the base class constructor.
    **kwargs
        Accepted but not used by this wrapper.

    Examples
    --------
    >>> from causalcompass.algorithms import CUTS
    >>> model = CUTS(input_step=10, batch_size=32, weight_decay=0.001, device='cuda')
    >>> predicted_adj = model.run(X, true_cm=true_adj, mask=mask, original_data=X)
    >>> all_metrics, no_diag_metrics = model.eval(true_adj, predicted_adj)
    """

    def __init__(self, input_step=10, batch_size=32, weight_decay=0.001,
                 device='cuda', seed=None, **kwargs):
        """
        Initialize CUTS
        """
        super().__init__(seed=seed)
        self._eval_output_type = "continuous"
        self.device = device
        # Keys mirror the attribute names handed to create_cuts_config in run().
        self.config_params = {
            'cuts_input_step': input_step,
            'cuts_batch_size': batch_size,
            'cuts_weight_decay': weight_decay,
            'device': device,
        }

    def run(self, X, true_cm, mask=None, original_data=None):
        """
        Run CUTS algorithm.

        Parameters
        ----------
        X : array
            Interpolated or complete time series data, shape (T, p).
            An already expanded (T, p, 1) array is passed through unchanged.
        true_cm : array
            True causal matrix; forwarded to the underlying model's ``train``.
        mask : array, optional
            Data mask, shape (T, p). Defaults to all 1s.
        original_data : array, optional
            Original complete data (if available). Defaults to X.

        Returns
        -------
        Predicted adjacency matrix, shape (p, p), as returned by the
        underlying model's ``train`` call.
        """
        # Default handling
        if mask is None:
            mask = np.ones_like(X)
        if original_data is None:
            original_data = X

        # Import internal modules lazily, only when the algorithm is run.
        from .cuts_main import CUTS as CUTSModel
        from .utils.logger import MyLogger
        from .cuts_config import create_cuts_config

        # Fall back to CPU when a CUDA device was requested but is unavailable.
        cuda_missing = (
            self.device.startswith("cuda") and not torch.cuda.is_available()
        )
        device = torch.device("cpu" if cuda_missing else self.device)

        # Mock args consumed by the config factory.
        run_params = self.config_params.copy()
        # Index the variable axis directly so that both (T, p) and (T, p, 1)
        # inputs work; tuple-unpacking the shape rejects the 3-D form that the
        # reshape step below explicitly accepts.
        run_params['p'] = X.shape[1]
        if cuda_missing:
            # Keep the config consistent with the device the model really uses.
            run_params['device'] = "cpu"
        opt = create_cuts_config(SimpleNamespace(**run_params))

        temp_dir = tempfile.mkdtemp(prefix="cuts_temp_")
        try:
            # Built inside the try block so the directory is removed even if
            # the logger constructor itself fails.
            log = MyLogger(log_dir=temp_dir, stdout=False, stderr=False,
                           tensorboard=True)

            # Reshape data to (T, p, 1) as expected by CUTS
            # Check dimensions to avoid double expansion
            data_interp = X[:, :, np.newaxis] if X.ndim == 2 else X
            orig_data = (
                original_data[:, :, np.newaxis]
                if original_data.ndim == 2
                else original_data
            )
            data_mask = mask[:, :, np.newaxis] if mask.ndim == 2 else mask

            cuts_model = CUTSModel(opt, log, device=device)

            # Train and get prediction
            return cuts_model.train(data_interp, data_mask, orig_data, true_cm)
        finally:
            # Best-effort cleanup: a failed removal must not mask the result or
            # the original exception, but interrupts are no longer swallowed.
            shutil.rmtree(temp_dir, ignore_errors=True)