probabilistic_model.probabilistic_circuit.jax.coupling_circuit

Classes

Conditioner

Interface for a conditioner that generates parameters for a circuit.

CouplingCircuit

A probabilistic circuit that uses a function to generate parameters for a circuit.

LinearConditioner

A simple linear conditioner that generates parameters for a circuit.

Module Contents

class probabilistic_model.probabilistic_circuit.jax.coupling_circuit.Conditioner

Interface for a conditioner that generates parameters for a circuit.

abstract generate_parameters(x: jax.Array) → jax.Array

Generate parameters for a circuit given an input.

Parameters:

x – The input to the conditioner.

Returns:

The parameters for the circuit.

abstract property output_length

Returns:

The number of parameters that the model outputs.
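To make the contract concrete, the sketch below mirrors the two abstract members with Python's `abc` module. `ConditionerSketch` and `ConstantConditioner` are hypothetical names used only for illustration; they are not part of `probabilistic_model`, and a real conditioner would normally make its output depend on `x`.

```python
import abc

import jax
import jax.numpy as jnp


class ConditionerSketch(abc.ABC):
    """Hypothetical stand-in mirroring the documented Conditioner interface."""

    @abc.abstractmethod
    def generate_parameters(self, x: jax.Array) -> jax.Array:
        """Map a conditioner input to a flat vector of circuit parameters."""

    @property
    @abc.abstractmethod
    def output_length(self) -> int:
        """Number of parameters returned by generate_parameters."""


class ConstantConditioner(ConditionerSketch):
    """Hypothetical conditioner that ignores x and returns fixed parameters."""

    def __init__(self, params: jax.Array):
        self.params = params

    def generate_parameters(self, x: jax.Array) -> jax.Array:
        # The parameters do not depend on the input in this toy example.
        return self.params

    @property
    def output_length(self) -> int:
        return self.params.shape[0]


cond = ConstantConditioner(jnp.array([0.0, 1.0, 2.0]))
theta = cond.generate_parameters(jnp.ones(4))
print(theta.shape, cond.output_length)  # (3,) 3
```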

class probabilistic_model.probabilistic_circuit.jax.coupling_circuit.CouplingCircuit(conditioner: Conditioner, conditioner_columns: jax.Array, circuit: probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit.Layer, circuit_columns)

Bases: equinox.Module

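A probabilistic circuit that uses a function (the conditioner) to generate the parameters of another circuit. Simply speaking, it represents P(y | θ = f(x)), where f is the conditioner, x are the conditioner columns of the input and y are the circuit columns.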
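As a toy illustration of P(y | θ = f(x)), the sketch below lets a hypothetical linear function `f` produce the mean and log standard deviation of a one-dimensional Gaussian, which plays the role of the circuit. The weight matrix and the Gaussian choice are assumptions made for the example, not what `CouplingCircuit` uses internally.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def f(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical conditioner: maps x to (mean, log_std) of a Gaussian over y."""
    w = jnp.array([[1.0, 0.0], [0.0, 0.5]])
    return w @ x


def conditional_log_likelihood(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """log P(y | theta = f(x)) for a one-dimensional Gaussian 'circuit'."""
    mean, log_std = f(x)
    return norm.logpdf(y, loc=mean, scale=jnp.exp(log_std))


x = jnp.array([2.0, 0.0])  # theta = (2.0, 0.0), i.e. N(2, 1)
print(conditional_log_likelihood(x, jnp.array(2.0)))  # approximately -0.919 = -0.5 * log(2 * pi)
```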

conditioner: Conditioner

The conditioner that generates the parameters for the circuit.

circuit: probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit.Layer

The circuit whose parameters the conditioner generates.

conditioner_columns: jax.Array

The indices of the columns of the input matrix that the conditioner takes as input to produce the circuit parameters.

circuit_columns: jax.Array

The indices of the columns of the input matrix that the circuit evaluates when calculating likelihoods.
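The two column attributes split one data matrix into conditioner inputs and circuit inputs. The sketch below shows that split with plain JAX indexing on a toy matrix; the concrete column indices are hypothetical.

```python
import jax.numpy as jnp

# A toy data matrix with four columns (rows are samples).
data = jnp.arange(12.0).reshape(3, 4)

# Hypothetical split: columns 0 and 1 feed the conditioner,
# columns 2 and 3 are modelled by the circuit.
conditioner_columns = jnp.array([0, 1])
circuit_columns = jnp.array([2, 3])

x = data[:, conditioner_columns]  # inputs used to produce circuit parameters
y = data[:, circuit_columns]      # values whose likelihood the circuit computes
print(x.shape, y.shape)  # (3, 2) (3, 2)
```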

partition_circuit()

Partition the circuit into its parameters and its static structure.

property slices_of_parameters_for_flat_model: typing_extensions.List[typing_extensions.Tuple[int, int]]

Returns:

The (start, end) index pairs that split the flat parameter vector generated by the conditioner into the structure of the circuit.
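One straightforward way to obtain such (start, end) pairs is a running sum over the parameter counts of each part of the circuit. The helper `slices_from_sizes` and the sizes below are hypothetical; the sketch only shows how pairs of this shape partition a flat vector.

```python
import itertools
from typing import List, Tuple

import jax.numpy as jnp


def slices_from_sizes(sizes: List[int]) -> List[Tuple[int, int]]:
    """Turn per-component parameter counts into (start, end) index pairs."""
    ends = list(itertools.accumulate(sizes))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


# Hypothetical circuit with three parameter groups of sizes 2, 3 and 1.
slices = slices_from_sizes([2, 3, 1])
print(slices)  # [(0, 2), (2, 5), (5, 6)]

flat = jnp.arange(6.0)  # stand-in for the conditioner output
chunks = [flat[start:end] for start, end in slices]
```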

create_circuit_from_parameters(params: jax.Array) → probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit.Layer

Generate a circuit with the structure from self.circuit and the parameters from params.

Parameters:

params – The parameters to be used in the circuit.

Returns:

The circuit.
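Writing a flat parameter vector back into a nested structure can be done in JAX with `jax.flatten_util.ravel_pytree`, which returns the flattened leaves together with a function that undoes the flattening. The sketch uses a hypothetical dictionary of parameters as the template; it illustrates the idea and is not necessarily how `create_circuit_from_parameters` is implemented.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter structure of a tiny circuit.
template = {"means": jnp.zeros(2), "log_stds": jnp.zeros(2), "weights": jnp.zeros(3)}

# ravel_pytree returns the flat vector and a function that restores the structure.
# Dictionary leaves are flattened in sorted key order: log_stds, means, weights.
flat_template, unravel = ravel_pytree(template)

# Pretend these came from a conditioner whose output_length == flat_template.size.
params = jnp.arange(flat_template.size, dtype=flat_template.dtype)
circuit_params = unravel(params)
```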

conditional_log_likelihood_single(x)

Calculate the conditional log likelihood of a single data point.

Parameters:

x – The data point.

Returns:

The conditional log likelihood of the data point.

conditional_log_likelihood(x)

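`conditional_log_likelihood` has no description here. Given its name and the single-point variant, one plausible reading is a batched version of `conditional_log_likelihood_single`; the standard JAX tool for that is `jax.vmap`. The sketch below is a hypothetical toy in which column 0 feeds the conditioner and column 1 the circuit; whether the library vectorizes this way is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def conditional_log_likelihood_single(row: jax.Array) -> jax.Array:
    """Hypothetical per-row log P(y | theta = f(x)).

    Column 0 is x (conditioner input), column 1 is y (circuit input); the
    toy conditioner sets the Gaussian mean to x and keeps the std at 1.
    """
    x, y = row[0], row[1]
    return norm.logpdf(y, loc=x, scale=1.0)


# Map the single-row function over the leading (batch) axis.
conditional_log_likelihood = jax.vmap(conditional_log_likelihood_single)

batch = jnp.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
ll = conditional_log_likelihood(batch)  # shape (3,)
```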
validate()

Check whether the output of the conditioner matches the parametrization of the circuit.
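A natural form of this check compares the conditioner's `output_length` with the total number of circuit parameters. The sketch below counts parameters with `ravel_pytree`; `validate_sketch` and the parameter dictionary are hypothetical, and the actual method may compare different quantities.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def validate_sketch(output_length: int, circuit_params) -> None:
    """Hypothetical check: the conditioner output size must equal the circuit's parameter count."""
    flat, _ = ravel_pytree(circuit_params)
    if output_length != flat.size:
        raise ValueError(
            f"Conditioner produces {output_length} parameters, "
            f"but the circuit needs {flat.size}."
        )


circuit_params = {"means": jnp.zeros(2), "weights": jnp.zeros(3)}
validate_sketch(5, circuit_params)  # passes silently
```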

class probabilistic_model.probabilistic_circuit.jax.coupling_circuit.LinearConditioner(in_features: int, out_features: int)

Bases: equinox.Module, Conditioner

A simple linear conditioner that generates parameters for a circuit.

linear: equinox.nn.Linear

The linear layer that maps the conditioner input to the circuit parameters.

generate_parameters(x: jax.Array) → jax.Array

Generate parameters for a circuit given an input.

Parameters:

x – The input to the conditioner.

Returns:

The parameters for the circuit.

property output_length

Returns:

The number of parameters that the model outputs.