
Inference Engine

InferenceEngine

Base class for inference engines. Subclasses implement `infer`, which fits the given model, and `predict`, which generates samples from the fitted model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | Callable | The model to be used for inference. | required |
| rng_key | Optional[PRNGKey] | The random number generator key. If not provided, a default key, `jax.random.PRNGKey(0)`, is used. | None |

Attributes:

| Name | Type | Description |
|------|------|-------------|
| model | Callable | The model used for inference. |
| rng_key | PRNGKey | The random number generator key. |

Source code in src/prophetverse/engine.py
from typing import Callable

import jax


class InferenceEngine:
    """Class representing an inference engine for a given model.

    Args:
        model (Callable): The model to be used for inference.
        rng_key (Optional[jax.random.PRNGKey]): The random number generator key. 
            If not provided, a default key with value 0 will be used.

    Attributes:
        model (Callable): The model used for inference.
        rng_key (jax.random.PRNGKey): The random number generator key.

    """

    def __init__(self, model: Callable, rng_key=None):
        self.model = model
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        self.rng_key = rng_key

    def infer(self, **kwargs): 
        """Performs inference using the specified model.

        Args:
            **kwargs: Additional keyword arguments to be passed to the model.

        Returns:
            The result of the inference.

        """
        ...

    def predict(self, **kwargs): 
        """Generates predictions using the specified model.

        Args:
            **kwargs: Additional keyword arguments to be passed to the model.

        Returns:
            The predictions generated by the model.

        """
        ...
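The base class only fixes an interface: `infer(**kwargs)` fits the model and returns the engine, and `predict(**kwargs)` returns a dictionary of sample arrays keyed by site name. The sketch below illustrates that contract with a hypothetical subclass, `ConjugateNormalEngine`, whose "model" is assumed to be a callable returning the prior location, prior scale and known observation scale of a Normal-Normal model. To stay self-contained it uses a hypothetical stand-in base class and a NumPy seed in place of `jax.random.PRNGKey`; none of these names are part of prophetverse.

```python
import numpy as np


class InferenceEngineSketch:
    """Hypothetical stand-in for InferenceEngine (not the prophetverse class)."""

    def __init__(self, model, rng_key=None):
        self.model = model
        # InferenceEngine defaults to jax.random.PRNGKey(0); seed 0 plays that role here.
        self.rng_key = 0 if rng_key is None else rng_key

    def infer(self, **kwargs):
        raise NotImplementedError

    def predict(self, **kwargs):
        raise NotImplementedError


class ConjugateNormalEngine(InferenceEngineSketch):
    """Hypothetical engine for mu ~ Normal(m0, s0), y_i ~ Normal(mu, sigma)."""

    def infer(self, y, **kwargs):
        m0, s0, sigma = self.model(**kwargs)
        y = np.asarray(y, dtype=float)
        # Closed-form Normal posterior for mu when sigma is known.
        precision = 1.0 / s0**2 + y.size / sigma**2
        self.posterior_mean_ = (m0 / s0**2 + y.sum() / sigma**2) / precision
        self.posterior_sd_ = precision**-0.5
        return self  # fitted state lives in trailing-underscore attributes

    def predict(self, num_samples=100, **kwargs):
        _, _, sigma = self.model(**kwargs)
        rng = np.random.default_rng(self.rng_key)
        mu = rng.normal(self.posterior_mean_, self.posterior_sd_, size=num_samples)
        obs = rng.normal(mu, sigma)  # one observation per posterior draw
        return {"mu": mu, "obs": obs}


def normal_model():
    return 0.0, 10.0, 1.0  # prior location, prior scale, observation scale


samples = ConjugateNormalEngine(normal_model).infer(y=[1.5, 2.0, 2.5]).predict(num_samples=5)
```

Chaining `infer(...).predict(...)` works because `infer` returns the engine, the same pattern `MAPInferenceEngine` and `MCMCInferenceEngine` follow.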

infer(**kwargs)

Performs inference using the specified model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments to be passed to the model. | {} |

Returns:

| Type | Description |
|------|-------------|
| | The result of the inference. |


predict(**kwargs)

Generates predictions using the specified model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments to be passed to the model. | {} |

Returns:

| Type | Description |
|------|-------------|
| | The predictions generated by the model. |


MAPInferenceEngine

Bases: InferenceEngine

Maximum a Posteriori (MAP) Inference Engine.

This class performs MAP inference using Stochastic Variational Inference (SVI) with an AutoDelta guide, which fits a single point estimate for each latent variable. It provides methods for inference and prediction.
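Why a Delta guide yields a MAP estimate can be sketched in a few lines. Write the AutoDelta guide as a point mass at a location theta and treat its log-density at that location as zero (NumPyro's `Delta` distribution uses `log_density=0.0` by default). The evidence lower bound that SVI maximizes then reduces to the log joint density evaluated at theta, whose maximizer is the posterior mode:

```latex
\begin{aligned}
\mathcal{L}(\theta)
  &= \mathbb{E}_{q_\theta(z)}\big[\log p(x, z) - \log q_\theta(z)\big],
  \qquad q_\theta(z) = \delta(z - \theta) \\
  &= \log p(x, \theta) - \underbrace{\log q_\theta(\theta)}_{=\,0} \\
  &= \log p(x \mid \theta) + \log p(\theta), \\[4pt]
\hat{\theta}
  &= \arg\max_\theta \mathcal{L}(\theta)
   = \arg\max_\theta p(\theta \mid x),
  \qquad \text{since } p(\theta \mid x) \propto p(x \mid \theta)\, p(\theta).
\end{aligned}
```

Under this reading, SVI with `Trace_ELBO` and an `AutoDelta` guide is gradient-based maximization of the log joint, i.e. a point estimate rather than a full posterior approximation.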

Parameters:

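| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | Callable | The probabilistic model to perform inference on. | required |
| optimizer_factory | Optional[Callable[[], _NumPyroOptim]] | A zero-argument callable that returns the NumPyro optimizer used for SVI. If None, `numpyro.optim.Adam(step_size=0.001)` is used. | None |
| num_steps | int | The number of optimization steps to perform. | 10000 |
| num_samples | int | The number of samples to draw in `predict`. | `_DEFAULT_PREDICT_NUM_SAMPLES` |
| rng_key | Optional[PRNGKey] | The random number generator key. If None, `jax.random.PRNGKey(0)` is used. | None |
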
Source code in src/prophetverse/engine.py
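from typing import Callable, Optional

import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_mean

# InferenceEngine and _DEFAULT_PREDICT_NUM_SAMPLES are defined elsewhere in engine.py.


class MAPInferenceEngine(InferenceEngine):
    """
    Maximum a Posteriori (MAP) Inference Engine.

    This class performs MAP inference using Stochastic Variational Inference (SVI) with an
    AutoDelta guide. It provides methods for inference and prediction.

    Args:
        model (Callable): The probabilistic model to perform inference on.
        optimizer_factory (Callable[[], numpyro.optim._NumPyroOptim], optional): A zero-argument
            callable returning the optimizer to use for SVI. Defaults to None, which uses
            numpyro.optim.Adam(step_size=0.001).
        num_steps (int, optional): The number of optimization steps to perform. Defaults to 10000.
        num_samples (int, optional): The number of samples to draw in predict.
            Defaults to _DEFAULT_PREDICT_NUM_SAMPLES.
        rng_key (jax.random.PRNGKey, optional): The random number generator key. Defaults to None.
    """

    def __init__(
        self,
        model: Callable,
        optimizer_factory: Optional[Callable[[], numpyro.optim._NumPyroOptim]] = None,
        num_steps=10000,
        num_samples=_DEFAULT_PREDICT_NUM_SAMPLES,
        rng_key=None,
    ):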
        if optimizer_factory is None:
            optimizer_factory = self.default_optimizer_factory
        self.optimizer_factory = optimizer_factory
        self.num_steps = num_steps
        self.num_samples = num_samples
        super().__init__(model, rng_key)

    def default_optimizer_factory(self):
        return numpyro.optim.Adam(step_size=0.001)

    def infer(self, **kwargs):
        """
        Perform MAP inference.

        Args:
            **kwargs: Additional keyword arguments to be passed to the model.

        Returns:
            self: The updated MAPInferenceEngine object.
        """
        self.guide_ = AutoDelta(self.model, init_loc_fn=init_to_mean())
        svi_ = SVI(self.model, self.guide_, self.optimizer_factory(), loss=Trace_ELBO())
        self.run_results_ = svi_.run(
            rng_key=self.rng_key, num_steps=self.num_steps, **kwargs
        )
        self.posterior_samples_ = self.guide_.sample_posterior(self.rng_key, params=self.run_results_.params, **kwargs)
        return self

    def predict(self, **kwargs):
        """
        Generate predictions using the trained model.

        Args:
            **kwargs: Additional keyword arguments to be passed to the model.

        Returns:
            Dict[str, jnp.ndarray]: The predicted samples generated by the model, also stored in self.samples_.
        """
        predictive = numpyro.infer.Predictive(
            self.model,
            params=self.run_results_.params,
            guide=self.guide_,
            #posterior_samples=self.posterior_samples_,
            num_samples=self.num_samples,
        )
        self.samples_ = predictive(
            rng_key=self.rng_key,
            **kwargs
        )
        return self.samples_

infer(**kwargs)

Perform MAP inference.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments to be passed to the model. | {} |

Returns:

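| Name | Type | Description |
|------|------|-------------|
| self | MAPInferenceEngine | The fitted engine, with `guide_`, `run_results_` and `posterior_samples_` set. |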
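A numerical sketch of what `infer` optimizes, for a hypothetical one-parameter model: a Normal(0, 10) prior on a mean `mu` and Normal observations with known scale 1. Instead of NumPyro's `SVI`, it runs a hand-written Adam update, using the engine's default step size of 0.001 and 10000 steps, directly on the negative log joint density, which is the loss that SVI with an `AutoDelta` guide minimizes. Because this model is conjugate, the result can be checked against the closed-form posterior mode. All names here are illustrative assumptions.

```python
import numpy as np

PRIOR_LOC, PRIOR_SCALE, SIGMA = 0.0, 10.0, 1.0  # hypothetical model constants


def grad_neg_log_joint(mu, y):
    # d/dmu of -[log Normal(mu | PRIOR_LOC, PRIOR_SCALE) + sum_i log Normal(y_i | mu, SIGMA)]
    return (mu - PRIOR_LOC) / PRIOR_SCALE**2 - np.sum(y - mu) / SIGMA**2


def map_by_adam(y, step_size=0.001, num_steps=10000, b1=0.9, b2=0.999, eps=1e-8):
    y = np.asarray(y, dtype=float)
    mu, m, v = 0.0, 0.0, 0.0  # point estimate plus Adam's first and second moments
    for t in range(1, num_steps + 1):
        g = grad_neg_log_joint(mu, y)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        m_hat = m / (1 - b1**t)  # bias-corrected moments
        v_hat = v / (1 - b2**t)
        mu -= step_size * m_hat / (np.sqrt(v_hat) + eps)
    return mu


def map_closed_form(y):
    # Normal prior + Normal likelihood gives a Normal posterior, so its mode equals its mean.
    y = np.asarray(y, dtype=float)
    precision = 1 / PRIOR_SCALE**2 + y.size / SIGMA**2
    return (PRIOR_LOC / PRIOR_SCALE**2 + y.sum() / SIGMA**2) / precision


y_obs = [1.5, 2.0, 2.5]
print(map_by_adam(y_obs), map_closed_form(y_obs))
```

The optimizer only ever updates a single number per latent site, which is what makes MAP fitting cheap compared with sampling.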


predict(**kwargs)

Generate predictions using the trained model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments to be passed to the model. | {} |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, jnp.ndarray] | The predicted samples generated by the model, also stored in `self.samples_`. |
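With an `AutoDelta` guide, every latent draw that `Predictive` takes from the guide is the same fitted point, so the spread in the returned samples comes only from the likelihood. The NumPy sketch below mimics that for a hypothetical model with a scalar mean `mu` and Normal observations of known scale; the function name and arguments are illustrative, not prophetverse API.

```python
import numpy as np


def map_predictive(mu_hat, sigma, num_obs, num_samples, seed=0):
    rng = np.random.default_rng(seed)
    # A point-mass guide returns the fitted location for every sample.
    mu = np.full(num_samples, mu_hat)
    # Observation noise is the only source of variation across samples.
    obs = rng.normal(mu[:, None], sigma, size=(num_samples, num_obs))
    return {"mu": mu, "obs": obs}


samples = map_predictive(mu_hat=2.0, sigma=1.0, num_obs=5, num_samples=2000)
```

Because parameter uncertainty is not propagated, predictive intervals from a MAP fit can look narrower than those from `MCMCInferenceEngine` on the same model.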


MCMCInferenceEngine

Bases: InferenceEngine

MCMCInferenceEngine performs Markov chain Monte Carlo (MCMC) inference for a given model using the No-U-Turn Sampler (NUTS).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | Callable | The model function to perform inference on. | required |
| num_samples | int | The number of MCMC samples to draw per chain, after warmup. | 1000 |
| num_warmup | int | The number of warmup samples to discard. | 200 |
| num_chains | int | The number of MCMC chains to run in parallel. | 1 |
| dense_mass | bool | Whether to use a dense mass matrix for the NUTS sampler. | False |
| rng_key | Optional[PRNGKey] | The random number generator key. If None, `jax.random.PRNGKey(0)` is used. | None |

Attributes:

| Name | Type | Description |
|------|------|-------------|
| num_samples | int | The number of MCMC samples to draw per chain, after warmup. |
| num_warmup | int | The number of warmup samples to discard. |
| num_chains | int | The number of MCMC chains to run in parallel. |
| dense_mass | bool | Whether to use a dense mass matrix for the NUTS sampler. |
| mcmc_ | MCMC | The MCMC object used for inference. |
| posterior_samples_ | Dict[str, ndarray] | The posterior samples obtained from MCMC. |
| samples_predictive_ | Dict[str, ndarray] | The posterior predictive samples generated by `predict`. |
| samples_ | Dict[str, ndarray] | The posterior samples from `mcmc_.get_samples()`, set by `predict`. |
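The roles of `num_warmup`, `num_samples` and `num_chains` can be illustrated with a much simpler sampler. The sketch below swaps NUTS for random-walk Metropolis (NUTS is too long for a short example) on a hypothetical Normal-mean model: each chain discards its first `num_warmup` draws, keeps the next `num_samples`, and the chains are then concatenated, which mirrors how `MCMC.get_samples()` returns draws by default (`group_by_chain=False`). Unlike NUTS warmup, this warmup does no step-size or mass-matrix adaptation; it only discards early draws.

```python
import numpy as np


def log_posterior(mu, y, prior_loc=0.0, prior_scale=10.0, sigma=1.0):
    # Unnormalized log density of mu ~ Normal(prior_loc, prior_scale), y_i ~ Normal(mu, sigma).
    return -0.5 * ((mu - prior_loc) / prior_scale) ** 2 - 0.5 * np.sum(((y - mu) / sigma) ** 2)


def rw_metropolis(y, num_samples=1000, num_warmup=200, num_chains=1, step=0.5, seed=0):
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    draws = np.empty((num_chains, num_samples))
    for c in range(num_chains):
        mu = rng.normal()  # each chain gets its own starting point
        lp = log_posterior(mu, y)
        for t in range(num_warmup + num_samples):
            proposal = mu + step * rng.normal()
            lp_proposal = log_posterior(proposal, y)
            if np.log(rng.uniform()) < lp_proposal - lp:  # Metropolis accept/reject
                mu, lp = proposal, lp_proposal
            if t >= num_warmup:  # warmup draws are discarded
                draws[c, t - num_warmup] = mu
    return draws.reshape(-1)  # chains flattened into num_chains * num_samples draws


samples = rw_metropolis([1.5, 2.0, 2.5], num_samples=4000, num_warmup=500, num_chains=2)
```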

Source code in src/prophetverse/engine.py
from typing import Callable

from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.initialization import init_to_mean

# InferenceEngine is defined earlier in engine.py.


class MCMCInferenceEngine(InferenceEngine):
    """
    Performs Markov chain Monte Carlo (MCMC) inference for a given model using the
    No-U-Turn Sampler (NUTS).

    Args:
        model (Callable): The model function to perform inference on.
        num_samples (int): The number of MCMC samples to draw per chain, after warmup.
        num_warmup (int): The number of warmup samples to discard.
        num_chains (int): The number of MCMC chains to run in parallel.
        dense_mass (bool): Whether to use a dense mass matrix for the NUTS sampler.
        rng_key (Optional): The random number generator key.

    Attributes:
        num_samples (int): The number of MCMC samples to draw per chain, after warmup.
        num_warmup (int): The number of warmup samples to discard.
        num_chains (int): The number of MCMC chains to run in parallel.
        dense_mass (bool): Whether to use a dense mass matrix for the NUTS sampler.
        mcmc_ (MCMC): The MCMC object used for inference.
        posterior_samples_ (Dict[str, np.ndarray]): The posterior samples obtained from MCMC.
        samples_predictive_ (Dict[str, np.ndarray]): The posterior predictive samples generated by predict.
        samples_ (Dict[str, np.ndarray]): The posterior samples from mcmc_.get_samples(), set by predict.

    """

    def __init__(
        self,
        model: Callable,
        num_samples=1000,
        num_warmup=200,
        num_chains=1,
        dense_mass=False,
        rng_key=None,
    ):
        self.num_samples = num_samples
        self.num_warmup = num_warmup
        self.num_chains = num_chains
        self.dense_mass = dense_mass
        super().__init__(model, rng_key)

    def infer(self, **kwargs):
        """
        Run MCMC inference.

        Args:
            **kwargs: Additional keyword arguments to be passed to the MCMC run method.

        Returns:
            self: The MCMCInferenceEngine object.

        """
        self.mcmc_ = MCMC(
            NUTS(self.model, dense_mass=self.dense_mass, init_strategy=init_to_mean()),
            num_samples=self.num_samples,
            num_warmup=self.num_warmup,
            num_chains=self.num_chains,  # run the configured number of chains
        )
        self.mcmc_.run(self.rng_key, **kwargs)
        self.posterior_samples_ = self.mcmc_.get_samples()
        return self

    def predict(self, **kwargs):
        """
        Generate predictive samples.

        Args:
            **kwargs: Additional keyword arguments to be passed to the Predictive method.

        Returns:
            Dict[str, np.ndarray]: The predictive samples.

        """

        predictive = Predictive(self.model, self.posterior_samples_, num_samples=self.num_samples)

        self.samples_predictive_ = predictive(self.rng_key, **kwargs)
        self.samples_ = self.mcmc_.get_samples()
        return self.samples_predictive_

infer(**kwargs)

Run MCMC inference.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments passed to `MCMC.run` and from there to the model. | {} |

Returns:

| Name | Type | Description |
|------|------|-------------|
| self | MCMCInferenceEngine | The fitted engine, with `mcmc_` and `posterior_samples_` set. |


predict(**kwargs)

Generate predictive samples.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | Additional keyword arguments passed to the `Predictive` call and from there to the model. | {} |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, np.ndarray] | The posterior predictive samples, also stored in `samples_predictive_`. |
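`Predictive` with posterior samples runs the model forward once per posterior draw, so the leading dimension of each returned array matches the number of (flattened) posterior draws. A NumPy sketch of that mapping for a hypothetical Normal-mean model follows; the function name, arguments and the `draws` stand-in are illustrative assumptions.

```python
import numpy as np


def posterior_predictive(mu_draws, sigma, num_obs, seed=0):
    mu_draws = np.asarray(mu_draws, dtype=float)
    rng = np.random.default_rng(seed)
    # Row i simulates num_obs observations from the model with mu fixed to mu_draws[i].
    obs = rng.normal(mu_draws[:, None], sigma, size=(mu_draws.size, num_obs))
    return {"mu": mu_draws, "obs": obs}


draws = np.linspace(1.0, 3.0, 8)  # stand-in for posterior_samples_["mu"]
pred = posterior_predictive(draws, sigma=1e-6, num_obs=3)
```

With a near-zero `sigma`, each row of `pred["obs"]` is essentially its posterior draw repeated, which makes the row-to-draw correspondence easy to see. Because each row uses a different posterior draw, parameter uncertainty carries into the predictions in addition to observation noise.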
