probabilistic_model.probabilistic_circuit.jax.utils
Functions
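- copy_bcoo(x)
- copy_bcsr(x)
- simple_interval_to_open_array(interval)
- create_bcoo_indices_from_row_lengths(row_lengths) – Create the indices of a BCOO array with the given row lengths.
- create_bcoo_indices_from_row_lengths_np(row_lengths) – Create the indices of a BCOO array with the given row lengths.
- create_bcsr_indices_from_row_lengths(row_lengths) – Create the column indices and index pointer of a BCSR array with the given row lengths.
- embed_sparse_array_in_nan_array(sparse_array)
- sample_from_sparse_probabilities_csc(probabilities, amount) – Sample from a sparse array of probabilities.
- remove_rows_and_cols_where_all(array, value) – Remove rows and columns from an array where all elements are equal to a given value.
- shrink_index_array(index_array) – Shrink an index array to only contain successive indices.
- sparse_remove_rows_and_cols_where_all(array, value) – Remove rows and columns from a sparse tensor where all elements are equal to a given value.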
Module Contents
- probabilistic_model.probabilistic_circuit.jax.utils.copy_bcoo(x: jax.experimental.sparse.BCOO) -> jax.experimental.sparse.BCOO
- probabilistic_model.probabilistic_circuit.jax.utils.copy_bcsr(x: jax.experimental.sparse.BCSR) -> jax.experimental.sparse.BCSR
- probabilistic_model.probabilistic_circuit.jax.utils.simple_interval_to_open_array(interval: random_events.interval.SimpleInterval) -> jax.numpy.array
- probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths(row_lengths: numpy.array) -> numpy.array
Create the indices of a BCOO array with the given row lengths.
The shape of the indices is (sum(row_lengths), 2), one (row, column) pair per stored element. The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).
Example:
>>> row_lengths = jnp.array([2, 3])
>>> create_bcoo_indices_from_row_lengths(row_lengths)
[[0 0]
 [0 1]
 [1 0]
 [1 1]
 [1 2]]
- Parameters:
row_lengths – The row lengths.
- Returns:
The indices of the sparse tensor.
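A minimal JAX sketch of the idea, assuming each row's stored elements occupy columns 0 to row_length - 1 (as in the example above): every element gets its row id via `jnp.repeat`, and its column id is its position in the flat element list minus the start offset of its row. The helper name `bcoo_indices_from_row_lengths_sketch` is hypothetical and not part of this module.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def bcoo_indices_from_row_lengths_sketch(row_lengths: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: BCOO indices of shape (sum(row_lengths), 2), rows filled from column 0."""
    total = int(row_lengths.sum())
    # Row id of each stored element: row i appears row_lengths[i] times.
    rows = jnp.repeat(jnp.arange(row_lengths.shape[0]), row_lengths)
    # Offset of each row's first element within the flat list of elements.
    starts = jnp.cumsum(row_lengths) - row_lengths
    # Column id = position of the element minus the start offset of its row.
    cols = jnp.arange(total) - starts[rows]
    return jnp.stack([rows, cols], axis=1)


row_lengths = jnp.array([2, 3])
indices = bcoo_indices_from_row_lengths_sketch(row_lengths)
# Placeholder data: every stored value is 1.0, purely for illustration.
data = jnp.ones(indices.shape[0])
x = BCOO((data, indices), shape=(row_lengths.shape[0], int(row_lengths.max())))
print(x.todense())
```

The indices follow the JAX BCOO layout of one row per stored element, which is why the result has shape (sum(row_lengths), 2).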
- probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths_np(row_lengths: numpy.array) -> numpy.array
Create the indices of a BCOO array with the given row lengths.
The shape of the indices is (sum(row_lengths), 2), one (row, column) pair per stored element. The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).
Example:
>>> row_lengths = np.array([2, 3])
>>> create_bcoo_indices_from_row_lengths_np(row_lengths)
[[0 0]
 [0 1]
 [1 0]
 [1 1]
 [1 2]]
- Parameters:
row_lengths – The row lengths.
- Returns:
The indices of the sparse tensor.
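A loop-based NumPy sketch under the same assumption (each row is filled from column 0) could look as follows; `bcoo_indices_from_row_lengths_np_sketch` is a hypothetical name, not this module's implementation.

```python
import numpy as np


def bcoo_indices_from_row_lengths_np_sketch(row_lengths: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy-only helper: indices of shape (sum(row_lengths), 2)."""
    # One row id per stored element.
    rows = np.repeat(np.arange(len(row_lengths)), row_lengths)
    # Restart the column counter in every row by concatenating per-row ranges.
    cols = np.concatenate([np.arange(n) for n in row_lengths])
    return np.stack([rows, cols], axis=1)


print(bcoo_indices_from_row_lengths_np_sketch(np.array([2, 3])))
```

The Python loop keeps the sketch short but costs one `np.arange` call per row; subtracting each row's cumulative start offset from a single global `np.arange` avoids the loop.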
- probabilistic_model.probabilistic_circuit.jax.utils.create_bcsr_indices_from_row_lengths(row_lengths: jax.Array) -> typing_extensions.Tuple[jax.Array, jax.Array]
Create the column indices and index pointer of a BCSR array with the given row lengths.
The shape of the sparse tensor that the indices describe should be (len(row_lengths), max(row_lengths)).
Example:
>>> row_lengths = jnp.array([2, 3])
>>> create_bcsr_indices_from_row_lengths(row_lengths)
(Array([0, 1, 0, 1, 2], dtype=int32), Array([0, 2, 5], dtype=int32))
- Parameters:
row_lengths – The row lengths.
- Returns:
A tuple of the column indices and the index pointer of the sparse tensor.
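In BCSR terms, the index pointer in the example is the cumulative sum of the row lengths with a leading zero, and the column indices restart at 0 in every row. A minimal sketch under that assumption, using the hypothetical helper `bcsr_indices_from_row_lengths_sketch` and placeholder data of ones:

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCSR


def bcsr_indices_from_row_lengths_sketch(row_lengths: jnp.ndarray):
    """Hypothetical helper: (column indices, index pointer) for rows filled from column 0."""
    # Stored elements of row i live in the slice indptr[i]:indptr[i + 1].
    indptr = jnp.concatenate([jnp.zeros(1, dtype=row_lengths.dtype), jnp.cumsum(row_lengths)])
    # Row id of each stored element.
    rows = jnp.repeat(jnp.arange(row_lengths.shape[0]), row_lengths)
    # Column id = position in the flat element list minus the start offset of its row.
    indices = jnp.arange(int(indptr[-1])) - indptr[rows]
    return indices, indptr


row_lengths = jnp.array([2, 3])
indices, indptr = bcsr_indices_from_row_lengths_sketch(row_lengths)
x = BCSR(
    (jnp.ones(indices.shape[0]), indices, indptr),
    shape=(row_lengths.shape[0], int(row_lengths.max())),
)
print(x.todense())
```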
- probabilistic_model.probabilistic_circuit.jax.utils.embed_sparse_array_in_nan_array(sparse_array: jax.experimental.sparse.BCOO) -> jax.Array
- probabilistic_model.probabilistic_circuit.jax.utils.sample_from_sparse_probabilities_csc(probabilities: scipy.sparse.csr_array, amount: numpy.array) -> scipy.sparse.csc_array
Sample from a sparse array of probabilities. Each row in the sparse array encodes a categorical probability distribution.
- Parameters:
probabilities – The sparse array of probabilities.
amount – The number of samples to draw from each row.
- Returns:
The samples that are drawn for each state, at the indices of probabilities.
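The docstring does not pin down the format of the returned array. The sketch below assumes it keeps the sparsity pattern of `probabilities` and stores, for each stored state, how many of the `amount[i]` samples of row i landed on that state; each row's counts are drawn with `numpy.random.Generator.multinomial`. The helper name `sample_counts_sketch` and its `seed` parameter are hypothetical, and every row is assumed to hold at least one stored probability.

```python
import numpy as np
from scipy.sparse import csc_array, csr_array


def sample_counts_sketch(probabilities: csr_array, amount: np.ndarray, seed: int = 0) -> csc_array:
    """Hypothetical helper: draw amount[i] samples from row i and count hits per stored state."""
    rng = np.random.default_rng(seed)
    counts = np.zeros_like(probabilities.data, dtype=np.int64)
    for row in range(probabilities.shape[0]):
        start, end = probabilities.indptr[row], probabilities.indptr[row + 1]
        # Only the stored (non-zero) states of this row can be drawn.
        counts[start:end] = rng.multinomial(amount[row], probabilities.data[start:end])
    # Same sparsity pattern as the input, with sample counts as the stored values.
    result = csr_array((counts, probabilities.indices, probabilities.indptr), shape=probabilities.shape)
    return result.tocsc()


probs = csr_array(np.array([[0.5, 0.0, 0.5], [0.0, 1.0, 0.0]]))
samples = sample_counts_sketch(probs, np.array([10, 4]))
print(samples.toarray())
```

Because the counts are written back in the order of `probabilities.data`, no index bookkeeping is needed beyond the row slices given by `indptr`.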
- probabilistic_model.probabilistic_circuit.jax.utils.remove_rows_and_cols_where_all(array: jax.Array, value: float) -> jax.Array
Remove rows and columns from an array where all elements are equal to a given value.
- Parameters:
array – The tensor to remove rows and columns from.
value – The value to remove.
- Returns:
The tensor without the rows and columns.
Example:
>>> a = jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
>>> remove_rows_and_cols_where_all(a, 0)
array([[1, 3],
       [7, 9]])
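A minimal eager-mode sketch with boolean masks, under the hypothetical name `remove_rows_and_cols_where_all_sketch`. The column mask is computed on the original array, which gives the same result because every removed row consists only of `value`.

```python
import jax.numpy as jnp


def remove_rows_and_cols_where_all_sketch(array: jnp.ndarray, value: float) -> jnp.ndarray:
    """Hypothetical helper: drop every row and column whose entries all equal `value`."""
    is_value = array == value
    keep_rows = ~jnp.all(is_value, axis=1)  # rows with at least one other entry
    keep_cols = ~jnp.all(is_value, axis=0)  # columns with at least one other entry
    # Boolean indexing produces data-dependent shapes, so this only works outside jax.jit.
    return array[keep_rows][:, keep_cols]


a = jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
print(remove_rows_and_cols_where_all_sketch(a, 0))
```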
- probabilistic_model.probabilistic_circuit.jax.utils.shrink_index_array(index_array: jax.Array) -> jax.Array
Shrink an index array to only contain successive indices.
Example:
>>> shrink_index_array(jnp.array([[0, 3], [1, 0], [4, 1]]))
[[0 2]
 [1 0]
 [2 1]]
- Parameters:
index_array – The index tensor to shrink.
- Returns:
The shrunken index tensor.
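Judging from the example, each column is relabelled independently so that its distinct values map to 0, 1, 2, ... in sorted order (in the second column, 0 → 0, 1 → 1, 3 → 2). A sketch of that interpretation using `jnp.unique(..., return_inverse=True)`; the name `shrink_index_array_sketch` is hypothetical.

```python
import jax.numpy as jnp


def shrink_index_array_sketch(index_array: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: relabel each column's distinct values as 0, 1, 2, ... in sorted order."""
    columns = []
    for j in range(index_array.shape[1]):
        # The inverse maps every entry to the position of its value among the sorted unique values.
        _, inverse = jnp.unique(index_array[:, j], return_inverse=True)
        columns.append(inverse.reshape(-1))
    return jnp.stack(columns, axis=1)


print(shrink_index_array_sketch(jnp.array([[0, 3], [1, 0], [4, 1]])))
```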
- probabilistic_model.probabilistic_circuit.jax.utils.sparse_remove_rows_and_cols_where_all(array: jax.experimental.sparse.BCOO, value: float) -> jax.experimental.sparse.BCOO
Remove rows and columns from a sparse tensor where all elements are equal to a given value.
Example:
>>> array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
>>> sparse_remove_rows_and_cols_where_all(array, 0).todense()
[[1 3]
 [7 9]]
- Parameters:
array – The sparse tensor to remove rows and columns from.
value – The value to remove.
- Returns:
The tensor without the unnecessary rows and columns.
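A sketch of the sparse case for a 2-D BCOO, assuming `value` is 0, the implicit fill value of the sparse array (for any other value, a row without stored entries would not consist only of `value`). It keeps only the non-zero stored entries and renumbers the surviving row and column ids to 0, 1, 2, ... in sorted order. The name `sparse_remove_zero_rows_and_cols_sketch` is hypothetical, and the data-dependent shapes mean it only runs outside `jax.jit`.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def sparse_remove_zero_rows_and_cols_sketch(array: BCOO) -> BCOO:
    """Hypothetical helper: drop rows and columns of a 2-D BCOO with no non-zero stored entry."""
    keep = array.data != 0  # ignore explicitly stored zeros
    data, indices = array.data[keep], array.indices[keep]
    # Renumber surviving row and column ids; sorted order is preserved.
    unique_rows, new_rows = jnp.unique(indices[:, 0], return_inverse=True)
    unique_cols, new_cols = jnp.unique(indices[:, 1], return_inverse=True)
    new_indices = jnp.stack([new_rows.reshape(-1), new_cols.reshape(-1)], axis=1)
    return BCOO((data, new_indices), shape=(unique_rows.shape[0], unique_cols.shape[0]))


array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
print(sparse_remove_zero_rows_and_cols_sketch(array).todense())
```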