Implementing the instrumental-variable and backdoor identification criteria, plus propensity-score and regression estimators

Notes

  • Causal Bayesian Networkx: here

dowhy > causal_model.py CausalModel > causal_graph.py CausalGraph

# notebook setup: inline matplotlib figures and automatic reloading of edited modules
%matplotlib inline
%load_ext autoreload
%autoreload 2
plt.style.use('bmh')
# graph configuration: two treatments, one outcome, one common cause, one effect modifier, no instruments
treatments = ['V0', 'V1']
outcome = 'Y'
common_causes = ['W0']
effect_modifiers = ['X0']
instruments = []
observed_nodes = treatments + [outcome] + instruments + effect_modifiers + common_causes
# the two flags below are not used by dowhy's graph here; they are passed to CausalGraph in the next cell
add_unobserved_confounder = True
missing_nodes_as_confounders = True
# dowhy's own CausalGraph for the same configuration (presumably kept as a reference for the
# notebook's CausalGraph class below - confirm)
cg_ref = dw.causal_graph.CausalGraph(treatments,
                                 [outcome], graph=None,
                                 common_cause_names=common_causes,
                                 instrument_names=instruments,
                                 effect_modifier_names=effect_modifiers,
                                 observed_node_names=observed_nodes)
# node attributes; dowhy marks the added unobserved confounder 'U' with observed='no'
cg_ref._graph.nodes(data=True)
NodeDataView({'V0': {'observed': 'yes'}, 'V1': {'observed': 'yes'}, 'Y': {'observed': 'yes'}, 'W0': {'observed': 'yes'}, 'X0': {'observed': 'yes'}, 'U': {'label': 'Unobserved Confounders', 'observed': 'no'}})

class CausalGraph[source]

CausalGraph(treatments:List[str], outcome:str='Y', common_causes:List[str]=None, effect_modifiers:List[str]=None, instruments:List[str]=None, observed_nodes:List[str]=None, missing_nodes_as_confounders:bool=False, add_unobserved_confounder:bool=True)

# the notebook's CausalGraph, built from the same configuration as cg_ref above
cg = CausalGraph(treatments=treatments, outcome=outcome, common_causes=common_causes,
                 effect_modifiers=effect_modifiers, observed_nodes=observed_nodes,
                 missing_nodes_as_confounders=missing_nodes_as_confounders,
                 add_unobserved_confounder=add_unobserved_confounder)
# here the unobserved confounder 'U' carries a boolean flag (False) instead of dowhy's 'no'
cg.g.nodes['U']['observed']
False

show_graph[source]

show_graph(g:Graph, kind:str='spectral')

view_graph[source]

view_graph(kind:str='spectral')

# draw the causal graph (default layout kind: 'spectral')
cg.view_graph()

get_ancestors[source]

get_ancestors(node:str, g:DiGraph=None, parents_only:bool=False)

# all ancestors of the treatment V0 in the graph
cg.get_ancestors('V0')
{'U', 'W0'}

cut_edges[source]

cut_edges(edges_to_cut:List[tuple]=None)

# graph with the edges U -> Y and W0 -> V1 removed, then plotted
g_cut = cg.cut_edges([('U','Y'), ('W0', 'V1')])
show_graph(g_cut)

get_causes[source]

get_causes(nodes:List[str], edges_to_cut:List[tuple]=None)

# causes of the treatment V0 (no edges cut)
cg.get_causes(['V0'])
{'U', 'W0'}

get_instruments[source]

get_instruments(treatments:List[str], outcome:str)

# instruments of the treatments w.r.t. the outcome (this configuration has instruments = [])
cg.get_instruments(treatments, outcome)
$\displaystyle \left\{\right\}$

get_effect_modifiers[source]

get_effect_modifiers(treatments:List[str], outcomes:List[str])

# effect modifiers of the treatments on the outcome
cg.get_effect_modifiers(treatments, [outcome])
['X0']

class CausalModel[source]

CausalModel(treatments:List[str], outcome:str='Y', common_causes:List[str]=None, effect_modifiers:List[str]=None, instruments:List[str]=None, causal_graph_kwargs=None)

treatments = ['V0',]  # single treatment here; 'V1' from the graph cells above is dropped
outcome = 'Y'
common_causes = ['W0']
effect_modifiers = ['X0']
instruments = []
# NOTE(review): unlike the graph cells above, common_causes (W0) are not listed as observed here - confirm intended
observed_nodes = treatments + [outcome] + instruments + effect_modifiers
add_unobserved_confounder = True
missing_nodes_as_confounders = True

# keyword arguments forwarded to the CausalGraph that CausalModel builds internally
cg_kwargs = dict(
    missing_nodes_as_confounders=missing_nodes_as_confounders,
    add_unobserved_confounder=add_unobserved_confounder,
    observed_nodes=observed_nodes
)
cm = CausalModel(treatments=treatments, outcome=outcome, common_causes=common_causes,
                 effect_modifiers=effect_modifiers,
                 causal_graph_kwargs=cg_kwargs)

identify_effect[source]

identify_effect(estimand_type:str='nonparametric-ate')

construct_backdoor[source]

construct_backdoor(treatments:List[str], outcome:str, common_causes:List[str], estimand_type:str='nonparametric-ate')

construct_instrumental_variable[source]

construct_instrumental_variable(treatments:List[str], outcome:str, instruments:List[str], estimand_type:str='nonparametric-ate')

# identify the effect: observed common causes plus the backdoor and instrumental-variable estimands
estimands = cm.identify_effect(); estimands
causes: {'treatments': {'U', 'W0'}, 'effects': {'U', 'X0', 'W0'}}
common causes: {'U', 'W0'}
Backdoor: Derivative(Expectation(Y | U,W0), [V0])
Instrumental variable: None
{'observed_common_causes': {'U', 'W0'},
 'backdoor': Derivative(Expectation(Y | U,W0), [V0]),
 'instrumental_variable': None}

Regression estimators based on sklearn regression classes

# sanity check: sklearn regressors derive from RegressorMixin, the model type RegressionEstimator expects
isinstance(linear_model.LinearRegression(), sklearn.base.RegressorMixin)
True

class RegressionEstimator[source]

RegressionEstimator(model:RegressorMixin)

Sanity checking on a quadratic polynomial toy dataset

# noise-free toy data: columns [x**2, x, 0.5] and y = 2*x**2 + 0*x + 0.5*0.5
X = np.linspace(-1, 1, 200)
X = np.array([X**2, X, np.ones(len(X))*.5]).T
w = np.array([2, 0, .5])
y = X @ w

# plot y against x (the middle column)
fig, ax = plt.subplots(figsize=(8,4), constrained_layout=True)
ax.scatter(X[:,-2], y)
ax.set(xlabel='x', ylabel='y', title='dataset')
plt.show()
# column 0 (x**2) plays the treatment and column 1 (x) the confounder; the estimate for moving
# the treatment from 0 to 1 should recover the true weight w[0] = 2
regression_model = linear_model.LinearRegression()
estimator = RegressionEstimator(regression_model)
estimator.fit(X, y, ix=0, ix_confounders=[1])
ate = estimator.estimate_effect(X=X, treatment=1, control=0)
print(f'ate = {ate:.3f} coefs {estimator.m.coef_}')
ate = 2.000 coefs [ 2.00000000e+00 -1.27928927e-17]

Classification estimator

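Propensity score matching: common causes -> classifier predicting the treatment -> its predicted probability of treatment is the propensity score -> each sample is paired with the most similar sample (by score) from the opposite group (treated vs. control) -> the outcome differences of the pairs are averaged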

The pairing is done with a nearest-neighbour search on the propensity score:

  • ATC: the nearest-neighbour index is built on the treated group; for each control sample the closest treated sample is looked up and the difference in outcome is computed
  • ATT: the nearest-neighbour index is built on the control group; for each treated sample the closest control sample is looked up and the difference in outcome is computed

TODO: test PropensityScoreMatcher on data generated using bcg.basics classes

# toy data with a known effect: the outcome is centred on the binary treatment (true ATE = 1);
# the two "common cause" columns are noisy copies of the treatment, centred on t and 10 - t
rng = np.random.RandomState(42)  # fixed seed so the printed estimates below are reproducible
n = 200
x_treatment = rng.choice([True, False], p=[.5, .5], size=n)
_t = x_treatment.astype(float)
x_common_causes = np.array([
    rng.normal(loc=_t, scale=.1),
    rng.normal(loc=10 - _t, scale=.1),
])

y_outcome = rng.normal(loc=_t, scale=.1)

# one histogram per generated variable, stacked vertically, to eyeball their distributions
panels = [
    (x_treatment.astype(float), 'treatment'),
    (x_common_causes[0], 'cc0'),
    (x_common_causes[1], 'cc1'),
    (y_outcome, 'outcome'),
]
fig, axs = plt.subplots(figsize=(8,6), nrows=len(panels), constrained_layout=True)
for ax, (values, label) in zip(axs, panels):
    ax.hist(values)
    ax.set(xlabel=label)
plt.show()

# outcome against (binary) treatment
fig, ax = plt.subplots(figsize=(8,4), constrained_layout=True)
ax.scatter(x_treatment, y_outcome)
ax.set(xlabel='treatment', ylabel='outcome', title='dataset')
plt.show()

# design matrix: treatment in column 0, the two common causes in columns 1-2 (cast to float);
# the outcome is the regression/matching target
X = np.column_stack((x_treatment, x_common_causes.T))
y = y_outcome
X.shape, y.shape
$\displaystyle \left( \left( 200, \ 3\right), \ \left( 200\right)\right)$
class PropensityScoreMatcher:
    '''Propensity-score matching estimator for a binary treatment.

    `fit` trains the classifier `propensity_model` to predict the treatment column of `X`
    from the confounder columns. `estimate_effect` uses the classifier's predicted
    probability of treatment as the propensity score. It matches every sample 1:1 with the
    closest-scoring sample of the opposite group (nearest-neighbour search) and averages
    the outcome differences of the matched pairs.
    '''

    def __init__(self, propensity_model:sklearn.base.ClassifierMixin):
        assert isinstance(propensity_model, sklearn.base.ClassifierMixin)
        self.pm = propensity_model

    def fit(self, X:np.ndarray, y:np.ndarray, ix:int, ix_confounders:List[int], reset:bool=True):
        '''Build the propensity model: treatment column ~ confounder columns.

        X: observations, one row per sample
        y: outcome; not used here, only accepted so all estimators share the same `fit` signature
        ix: column of X holding the treatment; needs to point to a binary variable
        ix_confounders: columns of X the treatment is predicted from
        reset: refit the propensity model; pass False to keep an already fitted one
        '''
        # always copy, so later changes to the caller's list do not leak into the estimator
        ix_confounders = list(ix_confounders)
        self.ix = ix
        self.ix_confounders = ix_confounders
        self._ix = [ix] + ix_confounders
        if reset:
            self.pm.fit(X[:, ix_confounders], X[:, ix])

    def _propensity_score(self, X_confounders:np.ndarray) -> np.ndarray:
        '''One score per row: the predicted probability of the last class in `self.pm.classes_`.

        For a binary treatment, matching on either class probability yields the same pairs.
        Classifiers without `predict_proba` fall back to `decision_function`, then to `predict`.
        '''
        if hasattr(self.pm, 'predict_proba'):
            return self.pm.predict_proba(X_confounders)[:, -1]
        if hasattr(self.pm, 'decision_function'):
            return np.ravel(self.pm.decision_function(X_confounders))
        return np.asarray(self.pm.predict(X_confounders), dtype=float)

    def estimate_effect(self, X:np.ndarray, treatment:Union[int, bool], control:Union[int, bool],
                        y:np.ndarray=None, kind:str='ate'):
        '''Estimate the treatment effect by 1:1 nearest-neighbour matching on the propensity score.

        X: observations with the same column layout as passed to `fit`
        treatment, control: values of the treatment column that define the two groups
        y: outcome per row of X (required; the None default only keeps the signature
            consistent with the other estimators)
        kind: 'att' - match each treated sample to its closest control sample,
              'atc' - match each control sample to its closest treated sample,
              'ate' - sample-size weighted average of ATT and ATC
        Raises ValueError if no row of X has the treatment value, or none has the control value.
        '''
        assert y is not None, 'Cannot be None. That\'s just the default to have consistent method parameters.'
        assert kind in ['ate', 'att', 'atc']
        # float so that boolean outcomes can be subtracted
        y = np.asarray(y, dtype=float)
        propensity_score = self._propensity_score(X[:, self.ix_confounders])
        ix_treat = X[:, self.ix] == treatment
        ix_control = X[:, self.ix] == control
        n_treat, n_cont = int(ix_treat.sum()), int(ix_control.sum())
        if n_treat == 0 or n_cont == 0:
            # fail early with a readable message instead of sklearn's "Found array with 0 sample(s)"
            raise ValueError(
                f'Matching needs at least one treated and one control sample, but column {self.ix} '
                f'has {n_treat} rows equal to treatment={treatment!r} and {n_cont} rows equal to '
                f'control={control!r}.')

        # NearestNeighbors expects 2-D inputs: a single feature (the score) per sample
        score_treat = propensity_score[ix_treat][:, None]
        score_cont = propensity_score[ix_control][:, None]
        y_treat, y_cont = y[ix_treat], y[ix_control]

        def closest(reference:np.ndarray, query:np.ndarray) -> np.ndarray:
            '''Row index into `reference` of the closest reference score for every query score.'''
            searcher = neighbors.NearestNeighbors(n_neighbors=1).fit(reference)
            _, indices = searcher.kneighbors(query)
            return indices[:, 0]

        def get_att():
            return np.mean(y_treat - y_cont[closest(score_cont, score_treat)])

        def get_atc():
            return np.mean(y_treat[closest(score_treat, score_cont)] - y_cont)

        if kind == 'att':
            return get_att()
        if kind == 'atc':
            return get_atc()
        if kind == 'ate':
            return (get_att() * n_treat + get_atc() * n_cont) / (n_treat + n_cont)
        raise NotImplementedError
# logistic-regression propensity model; treatment in column 0, the two common causes in columns 1-2
propensity_model = linear_model.LogisticRegression(solver='lbfgs')
estimator = PropensityScoreMatcher(propensity_model)
estimator.fit(X, y, ix=0, ix_confounders=[1, 2])
ate = estimator.estimate_effect(X=X, treatment=True, control=False, y=y)
print(f'ate = {ate:.3f}')
ate = 0.934

Generating data for the graphical model using bcg.basics functions

# binary treatment and binary outcome for the simulated data
outcome_is_binary = True
treatment_is_binary = True

# sample size, and the number of variables per role taken from the CausalModel configuration above
n = 333
n_common_causes = len(common_causes)
n_instruments = len(instruments)
n_eff_mods = len(effect_modifiers)
n_treatments = len(treatments)
beta = 1  # max random value

# draw each variable group; judging from the get_obs arguments, treatments are generated from
# common causes and instruments, and the outcome from treatments, common causes and effect modifiers
cc = CommonCauses.get_obs(n, n_common_causes)
ins = Instruments.get_obs(n, n_instruments)
em = EffectModifiers.get_obs(n, n_eff_mods)
treat = Treatments.get_obs(n, n_treatments, cc, ins, beta, treatment_is_binary=treatment_is_binary)
out = Outcomes.get_obs(treat, cc, em, outcome_is_binary=outcome_is_binary)
# one DataFrame with all observed columns, then split off the outcome as target
# (not_target presumably lists the feature column names - confirm)
obs = pd.concat((treat.obs, cc.obs, em.obs, ins.obs, out.obs), axis=1)
X, y, not_target = get_Xy(obs, target=outcome)
obs.head(), obs.tail()
(   V0        W0        X0  Y
 0   1  1.176682  0.468362  0
 1   0 -0.581795 -0.010592  1
 2   0  0.413044  0.275547  1
 3   0 -0.741007 -2.044762  1
 4   1 -0.167759 -2.053056  0,
      V0        W0        X0  Y
 328   0  0.937082  0.408209  0
 329   0 -0.020501  0.657049  0
 330   1  1.063346 -0.548377  0
 331   1  1.241344 -0.537261  1
 332   0  0.667529  0.745335  0)
# position of the treatment V0 among the feature columns
not_target.index('V0')
$\displaystyle 0$

Adding effect estimate functionality to CausalModel

Changing the implementation of get_Xy to incorporate products with effect modifiers, based on lines 59-71 in causal_estimators/linear_regression_estimator.py, via the new argument `feature_product_groups`. The argument should consist of two lists of feature names in obs; the products between every feature in the first list and every feature in the second list are computed.

get_Xy_with_products[source]

get_Xy_with_products(obs:DataFrame, target:str='Y', feature_product_groups:List[list]=None)

feature_product_groups (e.g. [["V0", "V1", "W0"], ["X0", "X1"]]): products are computed between each variable in the first list and each variable in the second list (not within a list)

# features including treatment x effect-modifier products (e.g. the column 'V0_X0')
get_Xy_with_products(obs, target=outcome, feature_product_groups=[treatments, effect_modifiers])
(array([[ 1.        ,  1.17668216,  0.46836244],
        [ 0.        , -0.58179495, -0.        ],
        [ 0.        ,  0.41304439,  0.        ],
        [ 0.        , -0.7410068 , -0.        ],
        [ 1.        , -0.1677586 , -2.05305636],
        [ 0.        ,  0.01860169,  0.        ],
        [ 0.        ,  0.43848953,  0.        ],
        [ 1.        , -0.87806002, -0.03161624],
        [ 0.        ,  0.69264442,  0.        ],
        [ 0.        , -0.19594743, -0.        ],
        [ 1.        ,  1.00767243,  0.62229535],
        [ 0.        ,  2.50225413, -0.        ],
        [ 1.        , -1.5025282 , -1.04862138],
        [ 0.        ,  0.59197037, -0.        ],
        [ 0.        ,  0.81509076, -0.        ],
        [ 1.        ,  0.82414504,  0.60770411],
        [ 0.        ,  0.47207575, -0.        ],
        [ 1.        ,  0.22302966, -1.34418616],
        [ 0.        ,  0.84251715, -0.        ],
        [ 1.        , -0.88151327,  0.24835875],
        [ 1.        , -0.30735475, -0.04794243],
        [ 0.        , -0.58668586,  0.        ],
        [ 1.        ,  0.27842421, -1.26786009],
        [ 1.        , -1.38543832, -0.21981882],
        [ 0.        ,  0.04636712,  0.        ],
        [ 1.        , -0.27225662, -1.49817642],
        [ 0.        ,  0.69317737, -0.        ],
        [ 0.        ,  0.55967209,  0.        ],
        [ 0.        , -0.28012082,  0.        ],
        [ 0.        ,  0.83954801, -0.        ],
        [ 1.        ,  1.17486959, -0.19711913],
        [ 0.        , -0.42327682, -0.        ],
        [ 1.        ,  1.19543944, -0.63228272],
        [ 1.        , -0.5933412 , -1.49673812],
        [ 0.        , -0.33224539,  0.        ],
        [ 1.        , -0.48398116, -0.14450943],
        [ 1.        ,  0.35354921,  0.03223502],
        [ 0.        ,  0.78950319,  0.        ],
        [ 1.        ,  0.27726173, -0.90122988],
        [ 0.        , -1.02464195,  0.        ],
        [ 1.        , -0.11925939,  0.1559452 ],
        [ 1.        ,  2.6178189 , -0.05506116],
        [ 0.        ,  1.0751829 , -0.        ],
        [ 1.        ,  1.74349409, -1.47659248],
        [ 0.        , -0.11530499,  0.        ],
        [ 1.        ,  1.15353089, -2.97294368],
        [ 1.        , -0.20033576, -1.74212625],
        [ 0.        , -1.38846114,  0.        ],
        [ 1.        , -0.36144888,  0.67152466],
        [ 1.        ,  0.66968925,  0.61413237],
        [ 0.        ,  0.16281224, -0.        ],
        [ 0.        ,  0.21976562, -0.        ],
        [ 1.        ,  1.53845996,  0.54539873],
        [ 0.        , -0.91126766, -0.        ],
        [ 1.        , -0.68694697, -0.67405784],
        [ 1.        , -1.18724641,  0.96144145],
        [ 1.        ,  0.10311262,  0.75524058],
        [ 0.        ,  0.29532879, -0.        ],
        [ 0.        ,  0.540325  ,  0.        ],
        [ 0.        , -2.04097777, -0.        ],
        [ 0.        ,  1.82561622, -0.        ],
        [ 0.        ,  0.09735642,  0.        ],
        [ 0.        , -0.6646908 , -0.        ],
        [ 1.        ,  1.12239077,  0.09404086],
        [ 1.        , -0.12017347,  0.35420544],
        [ 1.        , -0.65994112, -0.62236337],
        [ 0.        ,  1.27624662, -0.        ],
        [ 1.        , -1.51983953, -0.37870379],
        [ 1.        , -0.40253108,  1.22736594],
        [ 0.        ,  1.97445598,  0.        ],
        [ 0.        ,  0.19682415, -0.        ],
        [ 0.        ,  0.2096079 , -0.        ],
        [ 1.        , -0.82997588, -0.30663071],
        [ 0.        , -1.28976469, -0.        ],
        [ 1.        ,  1.42382724, -0.40233017],
        [ 1.        , -1.14401892,  0.55393806],
        [ 0.        , -0.93563609, -0.        ],
        [ 1.        , -0.66266796,  0.01546187],
        [ 1.        , -2.18297248,  0.8571488 ],
        [ 1.        ,  0.61298124,  0.3419954 ],
        [ 0.        ,  0.69951473, -0.        ],
        [ 1.        , -0.49642078, -0.61186666],
        [ 1.        ,  0.58792823, -0.72918475],
        [ 1.        ,  0.33747147, -0.68207888],
        [ 0.        ,  0.7830838 , -0.        ],
        [ 1.        , -0.19651434,  0.41325202],
        [ 0.        , -0.82630505, -0.        ],
        [ 0.        , -1.08308539, -0.        ],
        [ 1.        ,  1.51949011, -0.5132064 ],
        [ 0.        , -0.55609254, -0.        ],
        [ 0.        , -0.55020446, -0.        ],
        [ 1.        , -0.23046882, -1.18978961],
        [ 1.        , -0.83335204, -0.51365536],
        [ 0.        , -0.75398382, -0.        ],
        [ 0.        ,  1.0331082 , -0.        ],
        [ 0.        , -1.90836353, -0.        ],
        [ 1.        , -0.24204649, -1.16788193],
        [ 0.        , -0.48366006, -0.        ],
        [ 1.        ,  1.9609003 , -0.10092163],
        [ 1.        , -0.10930367,  0.19715841],
        [ 1.        , -0.91681464, -1.6097447 ],
        [ 0.        , -1.96662707, -0.        ],
        [ 0.        ,  1.56986699,  0.        ],
        [ 1.        , -0.01387704,  0.65497536],
        [ 0.        , -0.68650505, -0.        ],
        [ 0.        ,  0.21987463, -0.        ],
        [ 0.        , -1.30508432, -0.        ],
        [ 1.        ,  1.61626745, -1.02732012],
        [ 1.        , -0.31633112, -1.61368864],
        [ 1.        , -0.83786904,  1.04923957],
        [ 0.        ,  2.02425014, -0.        ],
        [ 0.        , -1.18013087,  0.        ],
        [ 1.        , -0.53768099,  0.35227688],
        [ 0.        , -0.73424031,  0.        ],
        [ 1.        ,  0.87101132,  0.22893354],
        [ 0.        ,  0.37984918, -0.        ],
        [ 0.        ,  0.66534242, -0.        ],
        [ 1.        , -1.64420658, -1.95857656],
        [ 1.        ,  0.1383424 , -1.55802608],
        [ 0.        , -0.74683721, -0.        ],
        [ 1.        , -0.43560403,  1.94214118],
        [ 1.        , -1.06956346, -0.5789204 ],
        [ 1.        ,  1.99088147, -0.14986403],
        [ 0.        , -0.12728962, -0.        ],
        [ 0.        ,  1.99533371, -0.        ],
        [ 1.        ,  0.1353928 , -1.37197832],
        [ 0.        ,  1.37546801,  0.        ],
        [ 0.        , -1.08458817,  0.        ],
        [ 0.        ,  0.25768835,  0.        ],
        [ 1.        , -1.97007802, -0.17201895],
        [ 0.        ,  0.88053984,  0.        ],
        [ 1.        ,  0.96000679, -0.33383015],
        [ 0.        , -0.02423215,  0.        ],
        [ 0.        ,  1.62918072, -0.        ],
        [ 0.        , -0.35953656, -0.        ],
        [ 1.        , -3.33968888, -0.1151471 ],
        [ 1.        , -1.22286802, -1.08098558],
        [ 1.        , -0.27731606, -1.75435442],
        [ 1.        , -0.14299429, -1.4113362 ],
        [ 0.        , -0.79238418,  0.        ],
        [ 0.        ,  0.15492188, -0.        ],
        [ 1.        ,  0.24196779, -1.23409943],
        [ 1.        ,  0.50064661, -0.11293217],
        [ 1.        , -0.37903674, -1.85947872],
        [ 0.        , -2.12054598, -0.        ],
        [ 0.        , -2.69900803, -0.        ],
        [ 1.        , -1.04118005, -0.87282059],
        [ 1.        , -1.26530384,  0.03916189],
        [ 1.        , -0.08328762, -0.93456518],
        [ 1.        , -0.57943718,  0.64581655],
        [ 1.        ,  0.45377051, -1.2342177 ],
        [ 1.        ,  0.77285238,  0.21019265],
        [ 1.        ,  1.99209498, -1.75214351],
        [ 0.        , -1.5709821 , -0.        ],
        [ 1.        ,  0.57874377, -0.49886028],
        [ 1.        ,  0.52104172, -1.6131055 ],
        [ 0.        , -1.29674892, -0.        ],
        [ 1.        ,  0.25799875, -0.28173773],
        [ 1.        ,  0.09164494, -0.93814027],
        [ 0.        , -0.19604978,  0.        ],
        [ 1.        ,  2.40007948,  0.13682563],
        [ 0.        , -0.31576543, -0.        ],
        [ 1.        ,  1.27486599,  1.58551915],
        [ 0.        , -0.0988648 , -0.        ],
        [ 0.        ,  0.57158242, -0.        ],
        [ 0.        , -0.46094401,  0.        ],
        [ 0.        ,  0.2513693 ,  0.        ],
        [ 0.        , -1.19324532, -0.        ],
        [ 0.        ,  0.16327776, -0.        ],
        [ 0.        , -0.55345049, -0.        ],
        [ 0.        ,  0.88697158,  0.        ],
        [ 1.        ,  0.31997452, -1.81553901],
        [ 1.        ,  1.20021945, -1.44244412],
        [ 1.        ,  1.47430806, -0.29657584],
        [ 1.        ,  2.2577756 , -0.47501888],
        [ 0.        , -1.95906451, -0.        ],
        [ 1.        , -0.33739542,  1.47269935],
        [ 1.        ,  0.11656465,  0.54634721],
        [ 1.        ,  0.56925007,  2.19265007],
        [ 0.        , -0.84636977,  0.        ],
        [ 1.        , -0.20755873,  0.52894867],
        [ 1.        ,  0.86748963,  0.25236621],
        [ 1.        ,  0.77730636,  0.88968126],
        [ 0.        ,  0.83362934, -0.        ],
        [ 0.        ,  1.2010197 , -0.        ],
        [ 0.        , -1.78993312,  0.        ],
        [ 0.        , -0.88537341, -0.        ],
        [ 1.        ,  0.6281433 ,  0.52400025],
        [ 1.        ,  0.1306515 , -1.51775631],
        [ 0.        , -1.11019716, -0.        ],
        [ 1.        , -0.1040678 ,  0.81633243],
        [ 0.        , -2.2783055 ,  0.        ],
        [ 0.        ,  0.32917672, -0.        ],
        [ 0.        ,  0.30402605, -0.        ],
        [ 1.        ,  0.65239543, -0.03703682],
        [ 0.        , -0.70969102, -0.        ],
        [ 1.        ,  0.76836166,  1.09724708],
        [ 0.        , -0.95307736, -0.        ],
        [ 1.        , -0.60126687, -0.68932624],
        [ 1.        ,  0.2611847 , -0.51854676],
        [ 1.        ,  0.86086798, -0.99311398],
        [ 0.        , -0.60212345, -0.        ],
        [ 0.        , -0.21489864, -0.        ],
        [ 1.        ,  0.16784098, -0.96451859],
        [ 1.        ,  0.99933976, -0.92318495],
        [ 1.        , -0.77836613, -0.27311638],
        [ 1.        , -0.58136406,  1.129748  ],
        [ 0.        , -0.99482575, -0.        ],
        [ 0.        , -2.26833038,  0.        ],
        [ 0.        , -0.49140647,  0.        ],
        [ 0.        ,  0.3417349 ,  0.        ],
        [ 0.        ,  0.81352755,  0.        ],
        [ 1.        ,  1.04808466, -1.95391987],
        [ 0.        ,  0.03739659, -0.        ],
        [ 1.        ,  1.90839777, -0.97396579],
        [ 1.        ,  0.35613674, -1.19637825],
        [ 1.        , -1.54520658, -0.47753458],
        [ 1.        ,  0.38647342,  0.31642644],
        [ 1.        , -0.26497912,  1.07614905],
        [ 0.        , -1.18843513, -0.        ],
        [ 0.        , -0.18649098, -0.        ],
        [ 1.        , -0.95641689, -0.53847434],
        [ 0.        , -0.62785403, -0.        ],
        [ 0.        ,  0.54499987,  0.        ],
        [ 1.        , -0.3661312 , -1.00765035],
        [ 0.        , -1.84270122, -0.        ],
        [ 1.        , -1.33943218, -0.6999864 ],
        [ 1.        , -0.1482295 , -0.50383975],
        [ 1.        , -0.30679679, -1.6769    ],
        [ 0.        ,  0.6370456 , -0.        ],
        [ 1.        ,  1.31814802, -0.06917642],
        [ 1.        ,  1.94500489,  0.707354  ],
        [ 0.        , -1.82858479, -0.        ],
        [ 0.        ,  1.91970659, -0.        ],
        [ 0.        ,  0.19284009,  0.        ],
        [ 1.        ,  0.60358981, -1.69540105],
        [ 0.        , -0.65389405,  0.        ],
        [ 0.        , -0.99792959,  0.        ],
        [ 0.        , -0.61591353,  0.        ],
        [ 0.        ,  0.9736039 ,  0.        ],
        [ 0.        , -2.2061269 , -0.        ],
        [ 0.        , -1.18717205, -0.        ],
        [ 0.        , -2.96283253,  0.        ],
        [ 0.        , -0.16640058,  0.        ],
        [ 1.        , -0.80142357,  0.80451183],
        [ 0.        ,  0.7034448 ,  0.        ],
        [ 1.        , -0.65441058, -1.11367468],
        [ 0.        ,  1.9866488 ,  0.        ],
        [ 0.        , -0.12074271,  0.        ],
        [ 0.        ,  0.85874995,  0.        ],
        [ 0.        , -0.23471034, -0.        ],
        [ 1.        , -0.69421745, -0.37865228],
        [ 0.        , -1.29753069, -0.        ],
        [ 1.        , -0.12132465, -2.64655954],
        [ 0.        ,  0.5908582 , -0.        ],
        [ 0.        , -0.13259458, -0.        ],
        [ 1.        ,  0.12074633, -2.19508328],
        [ 0.        ,  0.98194261, -0.        ],
        [ 0.        ,  0.76501008,  0.        ],
        [ 0.        ,  0.47128088,  0.        ],
        [ 0.        , -1.47328253, -0.        ],
        [ 1.        , -1.91590797, -0.96676211],
        [ 0.        , -0.46992243, -0.        ],
        [ 1.        , -0.54882431,  0.21779183],
        [ 1.        , -0.58274857, -2.89552988],
        [ 1.        ,  0.50685665, -1.53525251],
        [ 0.        , -0.80810694, -0.        ],
        [ 1.        , -0.22109691, -0.85674545],
        [ 0.        ,  0.57368129,  0.        ],
        [ 0.        , -0.86655538,  0.        ],
        [ 0.        ,  1.28184488,  0.        ],
        [ 0.        , -0.40320565,  0.        ],
        [ 0.        ,  0.12229881,  0.        ],
        [ 0.        ,  0.53200215, -0.        ],
        [ 0.        , -1.39066443, -0.        ],
        [ 0.        , -0.08015956, -0.        ],
        [ 1.        , -0.18230888, -0.2610354 ],
        [ 0.        ,  0.13218753, -0.        ],
        [ 1.        , -0.48240024,  0.28261661],
        [ 0.        ,  0.58114232, -0.        ],
        [ 1.        ,  0.24950259,  0.34883739],
        [ 0.        ,  0.23214457,  0.        ],
        [ 1.        ,  0.4265521 , -0.8793497 ],
        [ 1.        , -0.09835993, -0.23177168],
        [ 1.        ,  0.31605596, -0.27188055],
        [ 0.        ,  0.49403607, -0.        ],
        [ 1.        ,  0.03023592, -0.56611124],
        [ 0.        , -2.16333854,  0.        ],
        [ 1.        , -0.56309567,  1.3852228 ],
        [ 1.        , -0.11057424, -0.7816743 ],
        [ 0.        ,  0.92326045,  0.        ],
        [ 1.        , -1.57325458, -2.17917003],
        [ 0.        ,  0.51388745, -0.        ],
        [ 1.        , -1.33023702, -0.98718309],
        [ 1.        , -0.00646919, -1.45698329],
        [ 1.        , -0.68117076,  0.11792957],
        [ 0.        , -0.50013649, -0.        ],
        [ 1.        , -2.29327951, -1.69190102],
        [ 0.        , -0.10177061, -0.        ],
        [ 1.        ,  0.33030482, -0.99170174],
        [ 1.        , -1.20297097, -0.40180826],
        [ 0.        , -0.42249472, -0.        ],
        [ 1.        ,  0.74796245, -0.40187473],
        [ 0.        ,  0.29643631,  0.        ],
        [ 1.        , -0.9387942 ,  0.0142777 ],
        [ 0.        , -1.39634321,  0.        ],
        [ 1.        ,  0.74506433, -1.77284876],
        [ 1.        , -0.22198416,  1.61809186],
        [ 0.        , -0.60407254, -0.        ],
        [ 1.        , -0.43769244, -0.98667038],
        [ 1.        , -0.65880241, -0.65922251],
        [ 1.        ,  1.19145053, -1.51247836],
        [ 1.        , -2.42059679, -0.83186204],
        [ 1.        , -0.89557387, -0.096262  ],
        [ 0.        , -0.17293205,  0.        ],
        [ 0.        , -0.32774982,  0.        ],
        [ 0.        , -0.28746924,  0.        ],
        [ 1.        , -0.67139085,  1.60276736],
        [ 0.        ,  0.0357559 , -0.        ],
        [ 0.        ,  0.07941142,  0.        ],
        [ 0.        ,  0.24710014, -0.        ],
        [ 1.        , -0.79006552, -0.27965665],
        [ 0.        ,  0.21546596, -0.        ],
        [ 0.        , -1.11276003, -0.        ],
        [ 1.        ,  0.20444365, -1.07183991],
        [ 1.        , -0.15059638, -2.32557143],
        [ 0.        ,  0.45714317, -0.        ],
        [ 0.        ,  1.27504927, -0.        ],
        [ 0.        ,  0.93708227,  0.        ],
        [ 0.        , -0.02050105,  0.        ],
        [ 1.        ,  1.06334631, -0.54837674],
        [ 1.        ,  1.24134447, -0.53726089],
        [ 0.        ,  0.66752857,  0.        ]]),
 array([0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
        1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
        0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0,
        1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0,
        1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0,
        1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1,
        1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
        1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0]),
 ['V0', 'W0', 'V0_X0'])

estimate_effect[source]

estimate_effect(estimands:dict, control_value:float, treatment_name:str, treatment_value:float, obs:DataFrame, outcome:str='Y', causal_method:str='backdoor', model:Union[RegressorMixin, ClassifierMixin]=None, target_unit:str='ate', effect_modifiers:List[str]=None, supervised_type_is_regression:bool=True)

# estimate the effect of V0 on the outcome via the backdoor estimand; with model=None and
# supervised_type_is_regression=False a classification-based (propensity score) estimator is used
causal_method = 'backdoor'
# V0 is binary (treatment_is_binary=True above), so compare V0 == 1 against V0 == 0. A treatment
# value that never occurs in obs (previously 2) leaves no treated samples to match against.
control_value = 0
treatment_name = 'V0'  # must match the column name in obs exactly (previously 'v0')
treatment_value = 1

target_unit = 'ate'

# alternatives: linear_model.LinearRegression() or linear_model.LogisticRegression()
model = None
supervised_type_is_regression = False

# NOTE(review): the confounders printed by this call included the treatment V0 itself and the
# unobserved U, which is not a column of obs - check CausalModel.estimate_effect
cm.estimate_effect(estimands, control_value, treatment_name, treatment_value,
                   obs, outcome=outcome, causal_method=causal_method, model=model,
                   target_unit=target_unit, effect_modifiers=effect_modifiers,
                   supervised_type_is_regression=supervised_type_is_regression)
model None
confounders ['V0', 'U', 'W0', 'X0']
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-223-9adab0ab06b8> in <module>
     15                    obs, outcome=outcome, causal_method=causal_method, model=model,
     16                    target_unit=target_unit, effect_modifiers=effect_modifiers,
---> 17                    supervised_type_is_regression=supervised_type_is_regression)

<ipython-input-222-7ad792b78eaa> in estimate_effect(self, estimands, control_value, treatment_name, treatment_value, obs, outcome, causal_method, model, target_unit, effect_modifiers, supervised_type_is_regression)
     33     ix_confounders = [_i for _i,_v in enumerate(obs.columns.values) if _v in confounders]
     34     estimator.fit(X, y, ix, ix_confounders)
---> 35     effect = estimator.estimate_effect(X=X, treatment=treatment_value, control=control_value, y=y)
     36     return effect
     37 

<ipython-input-151-0802d1b99830> in estimate_effect(self, X, treatment, control, y, kind)
     64 
     65         if kind == 'ate':
---> 66             return get_ate()
     67         elif kind == 'att':
     68             return get_att()

<ipython-input-151-0802d1b99830> in get_ate()
     59             n_treat = ix_treat.sum()
     60             n_cont = ix_control.sum()
---> 61             att = get_att()
     62             atc = get_atc()
     63             return (att*n_treat + atc*n_cont) / (n_treat + n_cont)

<ipython-input-151-0802d1b99830> in get_att()
     34         def get_att():
     35             searcher.fit(propensity_score[ix_control][:,None])
---> 36             distances, indices = searcher.kneighbors(propensity_score[ix_treat][:,None])
     37 
     38             att = 0

/mnt/c/Programs_wsl/anaconda3/envs/py37_dowhy/lib/python3.7/site-packages/sklearn/neighbors/base.py in kneighbors(self, X, n_neighbors, return_distance)
    400         if X is not None:
    401             query_is_train = False
--> 402             X = check_array(X, accept_sparse='csr')
    403         else:
    404             query_is_train = True

/mnt/c/Programs_wsl/anaconda3/envs/py37_dowhy/lib/python3.7/site-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
    548                              " minimum of %d is required%s."
    549                              % (n_samples, array.shape, ensure_min_samples,
--> 550                                 context))
    551 
    552     if ensure_min_features > 0 and array.ndim == 2:

ValueError: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required.
a = np.linspace(1, 4, 5)
# the same values viewed as a single column, shape (5, 1)
b = a.reshape(-1, 1)
b