probabilistic_model.probabilistic_circuit.jax.utils

Functions

copy_bcoo(→ jax.experimental.sparse.BCOO)

copy_bcsr(→ jax.experimental.sparse.BCSR)

simple_interval_to_open_array(→ jax.numpy.array)

create_bcoo_indices_from_row_lengths(→ numpy.array)

Create the indices of a BCOO array with the given row lengths.

create_bcoo_indices_from_row_lengths_np(→ numpy.array)

Create the indices of a BCOO array with the given row lengths.

create_bcsr_indices_from_row_lengths(...)

Create the column indices and index pointer of a BCSR array with the given row lengths.

embed_sparse_array_in_nan_array(→ jax.Array)

sample_from_sparse_probabilities_csc(...)

Sample from a sparse array of probabilities.

remove_rows_and_cols_where_all(→ jax.Array)

Remove rows and columns from an array where all elements are equal to a given value.

shrink_index_array(→ jax.Array)

Shrink an index array to only contain successive indices.

sparse_remove_rows_and_cols_where_all(...)

Remove rows and columns from a sparse tensor where all elements are equal to a given value.

Module Contents

probabilistic_model.probabilistic_circuit.jax.utils.copy_bcoo(x: jax.experimental.sparse.BCOO) → jax.experimental.sparse.BCOO
probabilistic_model.probabilistic_circuit.jax.utils.copy_bcsr(x: jax.experimental.sparse.BCSR) → jax.experimental.sparse.BCSR
probabilistic_model.probabilistic_circuit.jax.utils.simple_interval_to_open_array(interval: random_events.interval.SimpleInterval) → jax.numpy.array
probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths(row_lengths: numpy.array) → numpy.array

Create the indices of a BCOO array with the given row lengths.

The shape of the indices is (sum(row_lengths), 2), i.e. one (row, column) pair per stored element. The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).

Example:

>>> row_lengths = jnp.array([2, 3])
>>> create_bcoo_indices_from_row_lengths(row_lengths)
    [[0 0]
     [0 1]
     [1 0]
     [1 1]
     [1 2]]

Parameters:

row_lengths – The row lengths.

Returns:

The indices of the sparse tensor
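The index layout described above can be produced with a repeat-and-offset trick. The following is a minimal sketch under that assumption, not the library's implementation; the name bcoo_indices_from_row_lengths_sketch is hypothetical. Each row index is repeated row_lengths[i] times, and the column index is the element's flat position minus the position where its row starts.

```python
import jax.numpy as jnp


def bcoo_indices_from_row_lengths_sketch(row_lengths) -> jnp.ndarray:
    """Hypothetical re-implementation: one (row, column) pair per stored element."""
    row_lengths = jnp.asarray(row_lengths)
    total = int(row_lengths.sum())  # number of stored elements, must be static
    # Row index of every stored element: row i appears row_lengths[i] times.
    rows = jnp.repeat(jnp.arange(row_lengths.shape[0]), row_lengths,
                      total_repeat_length=total)
    # Flat position of the first element of each row.
    starts = jnp.cumsum(row_lengths) - row_lengths
    # Column index restarts at 0 in every row.
    cols = jnp.arange(total) - starts[rows]
    return jnp.stack([rows, cols], axis=1)


indices = bcoo_indices_from_row_lengths_sketch(jnp.array([2, 3]))
print(indices)  # shape (sum(row_lengths), 2)
```

Passing total_repeat_length keeps the output shape static, which is what jnp.repeat needs when the repeats are not known at trace time; here it is computed eagerly, so the sketch itself is meant to be called outside jax.jit.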

probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths_np(row_lengths: numpy.array) → numpy.array

Create the indices of a BCOO array with the given row lengths.

The shape of the indices is (sum(row_lengths), 2), i.e. one (row, column) pair per stored element. The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).

Example:

>>> row_lengths = np.array([2, 3])
>>> create_bcoo_indices_from_row_lengths_np(row_lengths)
    [[0 0]
     [0 1]
     [1 0]
     [1 1]
     [1 2]]

Parameters:

row_lengths – The row lengths.

Returns:

The indices of the sparse tensor
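To show how such indices are consumed, the sketch below builds them with plain NumPy and then wraps them, together with one value per stored element, in a jax.experimental.sparse.BCOO of shape (len(row_lengths), max(row_lengths)). The helper name bcoo_indices_from_row_lengths_np_sketch and the example values are hypothetical; this is an illustration, not the module's code.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse


def bcoo_indices_from_row_lengths_np_sketch(row_lengths) -> np.ndarray:
    """Hypothetical NumPy-only variant of the index construction."""
    row_lengths = np.asarray(row_lengths)
    rows = np.repeat(np.arange(len(row_lengths)), row_lengths)
    starts = np.cumsum(row_lengths) - row_lengths  # first flat position of each row
    cols = np.arange(row_lengths.sum()) - starts[rows]
    return np.stack([rows, cols], axis=1)


row_lengths = np.array([2, 3])
indices = bcoo_indices_from_row_lengths_np_sketch(row_lengths)
# One value per stored element, listed row by row.
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
matrix = sparse.BCOO((values, jnp.asarray(indices)),
                     shape=(len(row_lengths), int(row_lengths.max())))
print(matrix.todense())  # positions past a row's length are implicit zeros
```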

probabilistic_model.probabilistic_circuit.jax.utils.create_bcsr_indices_from_row_lengths(row_lengths: jax.Array) → typing_extensions.Tuple[jax.Array, jax.Array]

Create the column indices and index pointer of a BCSR array with the given row lengths.

The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).

Example:

>>> row_lengths = jnp.array([2, 3])
>>> create_bcsr_indices_from_row_lengths(row_lengths)
(Array([0, 1, 0, 1, 2], dtype=int32), Array([0, 2, 5], dtype=int32))

Parameters:

row_lengths – The row lengths.

Returns:

The indices of the sparse tensor
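In CSR form the index pointer is the running total of the row lengths with a leading 0, and the column indices restart at 0 in every row. A minimal sketch of that relationship follows; bcsr_indices_from_row_lengths_sketch is a hypothetical name, and the result is fed into jax.experimental.sparse.BCSR only to check that it describes the expected matrix.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def bcsr_indices_from_row_lengths_sketch(row_lengths):
    """Hypothetical re-implementation returning (column indices, index pointer)."""
    row_lengths = jnp.asarray(row_lengths)
    total = int(row_lengths.sum())
    # indptr[i] is where row i starts in the data array; indptr[-1] == total.
    indptr = jnp.concatenate([jnp.zeros(1, dtype=row_lengths.dtype),
                              jnp.cumsum(row_lengths)])
    rows = jnp.repeat(jnp.arange(row_lengths.shape[0]), row_lengths,
                      total_repeat_length=total)
    # Column index = flat position minus the start of its row.
    indices = jnp.arange(total) - indptr[rows]
    return indices, indptr


indices, indptr = bcsr_indices_from_row_lengths_sketch(jnp.array([2, 3]))
matrix = sparse.BCSR((jnp.ones(5), indices, indptr), shape=(2, 3))
print(matrix.todense())
```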

probabilistic_model.probabilistic_circuit.jax.utils.embed_sparse_array_in_nan_array(sparse_array: jax.experimental.sparse.BCOO) → jax.Array
probabilistic_model.probabilistic_circuit.jax.utils.sample_from_sparse_probabilities_csc(probabilities: scipy.sparse.csr_array, amount: numpy.array) → scipy.sparse.csc_array

Sample from a sparse array of probabilities. Each row in the sparse array encodes a categorical probability distribution.

Parameters:
  • probabilities – The sparse array of probabilities.

  • amount – The number of samples to draw from each row.

Returns:

The samples drawn for each state stored at the indices of probabilities.
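One way to read this contract is that every CSR row is normalised into a categorical distribution over its stored states and amount[row] draws are recorded as per-state counts. The sketch below assumes exactly that interpretation (counts stored in the input's sparsity pattern, returned as CSC) and uses a per-row multinomial draw; the function name and the seed parameter are hypothetical, and the real function may encode samples differently.

```python
import numpy as np
import scipy.sparse as sp


def sample_counts_sketch(probabilities: sp.csr_array, amount: np.ndarray,
                         seed: int = 0) -> sp.csc_array:
    """Hypothetical sampler: how often each stored state of every row is drawn."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(probabilities.data.shape, dtype=np.int64)
    for row in range(probabilities.shape[0]):
        start, stop = probabilities.indptr[row], probabilities.indptr[row + 1]
        if start == stop:
            continue  # a row without stored states cannot be sampled
        p = probabilities.data[start:stop]
        # amount[row] categorical draws, expressed as counts per stored state.
        counts[start:stop] = rng.multinomial(int(amount[row]), p / p.sum())
    # Same sparsity pattern as the input; the data now holds the counts.
    samples = sp.csr_array((counts, probabilities.indices, probabilities.indptr),
                           shape=probabilities.shape)
    return samples.tocsc()


probabilities = sp.csr_array(np.array([[0.2, 0.0, 0.8],
                                       [0.0, 1.0, 0.0]]))
samples = sample_counts_sketch(probabilities, np.array([10, 4]))
print(samples.toarray())
```

Because only stored entries enter the multinomial draw, states with probability zero can never be sampled, and the loop touches each row's slice of data directly via indptr instead of densifying the matrix.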

probabilistic_model.probabilistic_circuit.jax.utils.remove_rows_and_cols_where_all(array: jax.Array, value: float) → jax.Array

Remove rows and columns from an array where all elements are equal to a given value.

Parameters:
  • array – The tensor to remove rows and columns from.

  • value – The value to remove.

Returns:

The tensor without the rows and columns.

Example:

>>> a = jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
>>> remove_rows_and_cols_where_all(a, 0)
array([[1, 3], [7, 9]])
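A dense version of this operation can be sketched with two boolean masks: a row or column is kept when at least one of its entries differs from value. The name remove_rows_and_cols_where_all_sketch is hypothetical and this is not the library's code.

```python
import jax.numpy as jnp


def remove_rows_and_cols_where_all_sketch(array: jnp.ndarray, value) -> jnp.ndarray:
    """Hypothetical dense version: keep rows/columns with at least one entry != value."""
    differs = array != value
    keep_rows = jnp.any(differs, axis=1)
    keep_cols = jnp.any(differs, axis=0)
    # Boolean masks give data-dependent shapes, so this only works outside jax.jit.
    return array[keep_rows][:, keep_cols]


a = jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
print(remove_rows_and_cols_where_all_sketch(a, 0))
```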
probabilistic_model.probabilistic_circuit.jax.utils.shrink_index_array(index_array: jax.Array) → jax.Array

Shrink an index array so that each column only contains successive indices (0, 1, 2, ...), preserving the relative order of the original values.

Example:

>>> shrink_index_array(jnp.array([[0, 3], [1, 0], [4, 1]]))
    [[0 2]
     [1 0]
     [2 1]]

Parameters:

index_array – The index tensor to shrink.

Returns:

The shrunken index tensor.
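The example above is consistent with re-numbering each column by the rank of its value among that column's unique values. Assuming that reading, numpy.unique with return_inverse=True does the work per column. The helper name shrink_index_array_sketch is hypothetical.

```python
import numpy as np
import jax.numpy as jnp


def shrink_index_array_sketch(index_array) -> jnp.ndarray:
    """Hypothetical version: re-number each column to 0, 1, 2, ... keeping order."""
    index_array = np.asarray(index_array)
    shrunk = np.empty_like(index_array)
    for dim in range(index_array.shape[1]):
        # return_inverse maps every entry to its rank among the column's unique values.
        _, inverse = np.unique(index_array[:, dim], return_inverse=True)
        shrunk[:, dim] = inverse
    return jnp.asarray(shrunk)


print(shrink_index_array_sketch(jnp.array([[0, 3], [1, 0], [4, 1]])))
```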

probabilistic_model.probabilistic_circuit.jax.utils.sparse_remove_rows_and_cols_where_all(array: jax.experimental.sparse.BCOO, value: float) → jax.experimental.sparse.BCOO

Remove rows and columns from a sparse tensor where all elements are equal to a given value.

Example:

>>> array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
>>> sparse_remove_rows_and_cols_where_all(array, 0).todense()
    [[1 3]
     [7 9]]

Parameters:
  • array – The sparse tensor to remove rows and columns from.

  • value – The value to remove.

Returns:

The tensor without the unnecessary rows and columns.
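For the common case value == 0 on a plain 2-D BCOO (no batch or dense dimensions), the operation reduces to dropping stored zeros and re-numbering the rows and columns that still appear in the indices, the same re-numbering idea as shrink_index_array. The sketch below is restricted to that case by assumption; for a nonzero value the implicit zeros would also have to be considered. The name sparse_remove_zero_rows_and_cols_sketch is hypothetical.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse


def sparse_remove_zero_rows_and_cols_sketch(array: sparse.BCOO) -> sparse.BCOO:
    """Hypothetical version for value == 0 on a plain 2-D BCOO."""
    data = np.asarray(array.data)
    indices = np.asarray(array.indices)
    # Explicitly stored zeros (e.g. padding) do not keep a row or column alive.
    nonzero = data != 0
    data, indices = data[nonzero], indices[nonzero]
    # Re-number the surviving rows and columns to 0, 1, 2, ...
    kept_rows, new_rows = np.unique(indices[:, 0], return_inverse=True)
    kept_cols, new_cols = np.unique(indices[:, 1], return_inverse=True)
    new_indices = np.stack([new_rows, new_cols], axis=1)
    return sparse.BCOO((jnp.asarray(data), jnp.asarray(new_indices)),
                       shape=(len(kept_rows), len(kept_cols)))


array = sparse.BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
print(sparse_remove_zero_rows_and_cols_sketch(array).todense())
```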