# Source code for causalcompass.algorithms.cutsplus.wrapper

import torch
import numpy as np
import tempfile
import shutil
from types import SimpleNamespace
from ..base import BaseCausalAlgorithm

class CUTSPlus(BaseCausalAlgorithm):
    """
    CUTS+ extends the original CUTS framework.

    References
    ----------
    https://github.com/jarrycyx/UNN

    Parameters
    ----------
    input_step : int, default 10
        Number of past time steps used as input
    batch_size : int, default 32
        Training batch size
    weight_decay : float, default 0.001
        Controls the strength of regularization
    device : str, default 'cuda'
        Computation device
    seed : optional, default None
        Forwarded unchanged to ``BaseCausalAlgorithm.__init__``.
    **kwargs
        Accepted and ignored by this wrapper.

    Examples
    --------
    >>> from causalcompass.algorithms import CUTSPlus
    >>> model = CUTSPlus(input_step=10, batch_size=32, weight_decay=0.001, device='cuda')
    >>> predicted_adj = model.run(X, true_cm=true_adj, mask=mask)
    >>> all_metrics, no_diag_metrics = model.eval(true_adj, predicted_adj)
    """

    def __init__(self, input_step=10, batch_size=32, weight_decay=0.001, device='cuda', seed=None, **kwargs):
        """
        Initialize CUTS+ and store its hyperparameters.

        The hyperparameters are kept in ``self.config_params`` under the
        ``cutsplus_``-prefixed keys that ``run`` later exposes as attributes
        of the argument namespace given to ``create_cutsplus_config``.
        """
        super().__init__(seed=seed)
        # NOTE(review): presumably read by the base class when evaluating
        # predictions -- confirm against BaseCausalAlgorithm.
        self._eval_output_type = "continuous"
        self.device = device
        self.config_params = {
            'cutsplus_input_step': input_step,
            'cutsplus_batch_size': batch_size,
            'cutsplus_weight_decay': weight_decay,
            'device': device,
        }

    def run(self, X, true_cm, mask=None):
        """
        Run CUTS+ algorithm.

        Parameters
        ----------
        X : array-like
            Time series data, shape (T, p).
        true_cm : array-like
            True causal matrix (p, p). Passed through to the CUTS+ entry point.
        mask : array-like, optional
            Data mask, shape (T, p). 1 for observed, 0 for missing.
            Defaults to all 1s.

        Returns
        -------
        Predicted adjacency matrix, shape (p, p).
        """
        # Handle defaults: treat every entry as observed.
        if mask is None:
            mask = np.ones_like(X)

        # Construct a mock args object carrying the stored hyperparameters
        # plus the number of variables inferred from the data.
        run_params = self.config_params.copy()
        _, p = X.shape
        run_params['p'] = p
        args_mock = SimpleNamespace(**run_params)

        # Imported lazily so the CUTS+ implementation is only loaded when
        # the algorithm is actually run.
        from .cuts_plus import main as cutsplus_main
        from .utils.logger import MyLogger
        from .cutsplus_config import create_cutsplus_config

        # Create configuration
        opt = create_cutsplus_config(args_mock)

        # Device handling: fall back to CPU when CUDA was requested but is
        # not available on this machine.
        if self.device.startswith("cuda") and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            device = torch.device("cpu")
        else:
            device = torch.device(self.device)

        temp_dir = tempfile.mkdtemp(prefix="cutsplus_temp_")
        try:
            # The logger is created inside the try block so the temporary
            # directory is removed even if the logger fails to initialise.
            log = MyLogger(log_dir=temp_dir, stdout=False, stderr=False, tensorboard=False)

            # Run the algorithm
            return cutsplus_main(
                data=X,
                mask=mask,
                true_cm=true_cm,
                opt=opt,
                log=log,
                device=device,
            )
        finally:
            # Best-effort cleanup: filesystem errors are ignored, but
            # KeyboardInterrupt/SystemExit are no longer swallowed as they
            # were by the previous bare ``except``.
            shutil.rmtree(temp_dir, ignore_errors=True)