probabilistic_model.probabilistic_circuit.jax.utils
===================================================

.. py:module:: probabilistic_model.probabilistic_circuit.jax.utils


Functions
---------

.. autoapisummary::

   probabilistic_model.probabilistic_circuit.jax.utils.copy_bcoo
   probabilistic_model.probabilistic_circuit.jax.utils.copy_bcsr
   probabilistic_model.probabilistic_circuit.jax.utils.simple_interval_to_open_array
   probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths
   probabilistic_model.probabilistic_circuit.jax.utils.create_bcoo_indices_from_row_lengths_np
   probabilistic_model.probabilistic_circuit.jax.utils.create_bcsr_indices_from_row_lengths
   probabilistic_model.probabilistic_circuit.jax.utils.embed_sparse_array_in_nan_array
   probabilistic_model.probabilistic_circuit.jax.utils.sample_from_sparse_probabilities_csc
   probabilistic_model.probabilistic_circuit.jax.utils.remove_rows_and_cols_where_all
   probabilistic_model.probabilistic_circuit.jax.utils.shrink_index_array
   probabilistic_model.probabilistic_circuit.jax.utils.sparse_remove_rows_and_cols_where_all


Module Contents
---------------

.. py:function:: copy_bcoo(x: jax.experimental.sparse.BCOO) -> jax.experimental.sparse.BCOO

.. py:function:: copy_bcsr(x: jax.experimental.sparse.BCSR) -> jax.experimental.sparse.BCSR

.. py:function:: simple_interval_to_open_array(interval: random_events.interval.SimpleInterval) -> jax.numpy.array

.. py:function:: create_bcoo_indices_from_row_lengths(row_lengths: numpy.array) -> numpy.array

   Create the indices of a BCOO array with the given row lengths.

   The indices have shape ``(sum(row_lengths), 2)``: one ``(row, column)`` pair per stored element.
   The sparse tensor that the indices describe has shape ``(len(row_lengths), max(row_lengths))``.

   Example::

       >>> import jax.numpy as jnp
       >>> row_lengths = jnp.array([2, 3])
       >>> print(create_bcoo_indices_from_row_lengths(row_lengths))
       [[0 0]
        [0 1]
        [1 0]
        [1 1]
        [1 2]]

   :param row_lengths: The number of stored elements in each row.
   :return: The indices of the sparse tensor.
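
   The following is a minimal NumPy sketch of the construction, for illustration only. The helper name
   ``bcoo_indices_from_row_lengths`` is hypothetical, and this is not necessarily how the module implements it.
   Row ``i`` contributes ``row_lengths[i]`` pairs, and their column coordinates count up from 0.

   .. code-block:: python

      import numpy as np


      def bcoo_indices_from_row_lengths(row_lengths: np.ndarray) -> np.ndarray:
          """Hypothetical sketch: one (row, column) pair per stored element."""
          row_lengths = np.asarray(row_lengths)
          total = int(row_lengths.sum())
          # Row coordinate of every entry: row i is repeated row_lengths[i] times.
          rows = np.repeat(np.arange(len(row_lengths)), row_lengths)
          # Offset at which each row starts in the flattened list of entries.
          starts = np.cumsum(row_lengths) - row_lengths
          # Column coordinate: position of each entry within its own row.
          cols = np.arange(total) - np.repeat(starts, row_lengths)
          # Shape (sum(row_lengths), 2), the (nse, n_sparse) layout that BCOO indices use.
          return np.stack([rows, cols], axis=1)


      print(bcoo_indices_from_row_lengths(np.array([2, 3])))

   For ``[2, 3]``, this prints the same five pairs as the example above. The sketch uses only
   ``np.repeat`` and ``np.cumsum``, so it needs no Python loop over the rows.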


.. py:function:: create_bcoo_indices_from_row_lengths_np(row_lengths: numpy.array) -> numpy.array

   Create the indices of a BCOO array with the given row lengths.

   The indices have shape ``(sum(row_lengths), 2)``: one ``(row, column)`` pair per stored element.
   The sparse tensor that the indices describe has shape ``(len(row_lengths), max(row_lengths))``.

   Example::

       >>> import numpy as np
       >>> row_lengths = np.array([2, 3])
       >>> print(create_bcoo_indices_from_row_lengths_np(row_lengths))
       [[0 0]
        [0 1]
        [1 0]
        [1 1]
        [1 2]]

   :param row_lengths: The number of stored elements in each row.
   :return: The indices of the sparse tensor.
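
   The dense shape ``(len(row_lengths), max(row_lengths))`` can be checked by scattering some values through
   the indices from the example above with plain NumPy. The stored values ``1..5`` are hypothetical.
   Rows shorter than the longest row come out padded with zeros on the right.

   .. code-block:: python

      import numpy as np

      row_lengths = np.array([2, 3])
      # Indices as documented above for row_lengths = [2, 3].
      indices = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]])
      values = np.arange(1, len(indices) + 1)  # hypothetical stored values 1..5

      # Dense shape implied by the row lengths: (len(row_lengths), max(row_lengths)).
      dense = np.zeros((len(row_lengths), int(row_lengths.max())), dtype=values.dtype)
      dense[indices[:, 0], indices[:, 1]] = values
      print(dense)  # [[1 2 0]
                    #  [3 4 5]]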


.. py:function:: create_bcsr_indices_from_row_lengths(row_lengths: jax.Array) -> typing_extensions.Tuple[jax.Array, jax.Array]

   Create the column indices and the index pointer (``indptr``) of a BCSR array with the given row lengths.

   The sparse tensor that the indices describe has shape ``(len(row_lengths), max(row_lengths))``.

   Example::

       >>> import jax.numpy as jnp
       >>> row_lengths = jnp.array([2, 3])
       >>> create_bcsr_indices_from_row_lengths(row_lengths)
       (Array([0, 1, 0, 1, 2], dtype=int32), Array([0, 2, 5], dtype=int32))

   :param row_lengths: The number of stored elements in each row.
   :return: A tuple ``(indices, indptr)`` holding the column indices and the index pointer of the sparse tensor.
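
   The following is a minimal NumPy sketch of the same two arrays. The helper name
   ``bcsr_indices_from_row_lengths`` is hypothetical. The index pointer is the running sum of the row
   lengths with a leading 0. Within each row, the column indices are ``0, 1, ..., row_length - 1``.

   .. code-block:: python

      import numpy as np


      def bcsr_indices_from_row_lengths(row_lengths: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
          """Hypothetical sketch returning (column indices, index pointer)."""
          row_lengths = np.asarray(row_lengths)
          # indptr[i]:indptr[i + 1] is the slice of stored entries that belongs to row i.
          indptr = np.concatenate([[0], np.cumsum(row_lengths)])
          # Subtract each row's start offset so columns restart at 0 in every row.
          indices = np.arange(int(indptr[-1])) - np.repeat(indptr[:-1], row_lengths)
          return indices, indptr


      indices, indptr = bcsr_indices_from_row_lengths(np.array([2, 3]))
      print(indices, indptr)  # [0 1 0 1 2] [0 2 5]

   The index pointer always has ``len(row_lengths) + 1`` entries, and its last entry equals the number of
   stored elements.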


.. py:function:: embed_sparse_array_in_nan_array(sparse_array: jax.experimental.sparse.BCOO) -> jax.Array

.. py:function:: sample_from_sparse_probabilities_csc(probabilities: scipy.sparse.csr_array, amount: numpy.array) -> scipy.sparse.csc_array

   Sample from a sparse array of probabilities.
   Each row in the sparse array encodes a categorical probability distribution.

   :param probabilities: The sparse array of probabilities.
   :param amount: The number of samples to draw from each row.
   :return: The samples drawn for each state in the indices of ``probabilities``.
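
   The following is a minimal SciPy sketch of per-row categorical sampling, under stated assumptions: the
   helper ``sample_counts_per_row`` and its ``seed`` parameter are hypothetical. The sketch also assumes that
   each row's stored entries are normalized and that the result records, at each stored position, how often
   that state was drawn. The output format of the documented function is not confirmed here.

   .. code-block:: python

      import numpy as np
      from scipy import sparse


      def sample_counts_per_row(probabilities: sparse.csr_array, amount: np.ndarray,
                                seed: int = 0) -> sparse.csc_array:
          """Hypothetical sketch: draw amount[i] samples from the distribution in row i.

          Assumption: each stored position receives the count of how often it was drawn.
          """
          rng = np.random.default_rng(seed)
          counts = np.zeros_like(probabilities.data, dtype=np.int64)
          for row in range(probabilities.shape[0]):
              start, stop = probabilities.indptr[row], probabilities.indptr[row + 1]
              if stop > start:
                  # One multinomial draw per row over that row's stored probabilities.
                  counts[start:stop] = rng.multinomial(int(amount[row]), probabilities.data[start:stop])
          # Reuse the CSR structure of the input, then convert to CSC as in the documented signature.
          result = sparse.csr_array((counts, probabilities.indices.copy(), probabilities.indptr.copy()),
                                    shape=probabilities.shape)
          return result.tocsc()


      probabilities = sparse.csr_array(np.array([[0.25, 0.0, 0.75], [0.0, 1.0, 0.0]]))
      samples = sample_counts_per_row(probabilities, np.array([4, 3]))
      print(samples.toarray())

   Implicit zeros are never drawn because only stored entries take part in the multinomial. The second row
   always yields ``[0 3 0]``, and the counts in the first row always sum to 4.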


.. py:function:: remove_rows_and_cols_where_all(array: jax.Array, value: float) -> jax.Array

   Remove rows and columns from an array where all elements are equal to a given value.

   :param array: The tensor to remove rows and columns from.
   :param value: The value to remove.
   :return: The tensor without the rows and columns.

   Example::

       >>> import jax.numpy as jnp
       >>> a = jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
       >>> print(remove_rows_and_cols_where_all(a, 0))
       [[1 3]
        [7 9]]
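
   The following is a minimal NumPy sketch of the same filtering. The helper name
   ``remove_constant_rows_and_cols`` is hypothetical. It builds one boolean mask for the rows and one for
   the columns, then selects their cross product with ``np.ix_``.

   .. code-block:: python

      import numpy as np


      def remove_constant_rows_and_cols(array: np.ndarray, value: float) -> np.ndarray:
          """Hypothetical sketch: drop rows and columns whose entries all equal `value`."""
          is_value = array == value
          keep_rows = ~np.all(is_value, axis=1)  # rows with at least one other entry
          keep_cols = ~np.all(is_value, axis=0)  # columns with at least one other entry
          # np.ix_ turns the two masks into an open mesh selecting kept rows x kept columns.
          return array[np.ix_(keep_rows, keep_cols)]


      a = np.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
      print(remove_constant_rows_and_cols(a, 0))  # [[1 3]
                                                  #  [7 9]]

   The output shape depends on the values in the array. In JAX, this kind of boolean-mask filtering
   therefore generally runs on concrete arrays rather than inside ``jax.jit``.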


.. py:function:: shrink_index_array(index_array: jax.Array) -> jax.Array

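   Shrink an index array so that each column contains only successive indices: the values of every column
   are remapped to ``0, 1, 2, ...`` while preserving their order.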

   Example::

       >>> import jax.numpy as jnp
       >>> print(shrink_index_array(jnp.array([[0, 3], [1, 0], [4, 1]])))
       [[0 2]
        [1 0]
        [2 1]]

   :param index_array: The index tensor to shrink.
   :return: The shrunken index tensor.
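
   The following is a minimal NumPy sketch of this remapping, assuming the behaviour shown in the example
   (a per-column rank of the distinct values). The helper name ``shrink_indices`` is hypothetical. For each
   column, ``np.unique(..., return_inverse=True)`` returns the position of every entry's value among the
   sorted distinct values, which is exactly the successive index.

   .. code-block:: python

      import numpy as np


      def shrink_indices(index_array: np.ndarray) -> np.ndarray:
          """Hypothetical sketch: remap each column to consecutive integers, keeping order."""
          shrunk = np.empty_like(index_array)
          for col in range(index_array.shape[1]):
              # inverse[k] is the rank of index_array[k, col] among the column's unique values.
              _, inverse = np.unique(index_array[:, col], return_inverse=True)
              shrunk[:, col] = inverse
          return shrunk


      print(shrink_indices(np.array([[0, 3], [1, 0], [4, 1]])))  # [[0 2]
                                                                 #  [1 0]
                                                                 #  [2 1]]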


.. py:function:: sparse_remove_rows_and_cols_where_all(array: jax.experimental.sparse.BCOO, value: float) -> jax.experimental.sparse.BCOO

   Remove rows and columns from a sparse tensor where all elements are equal to a given value.

   Example::

       >>> import jax.numpy as jnp
       >>> from jax.experimental.sparse import BCOO
       >>> array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
       >>> print(sparse_remove_rows_and_cols_where_all(array, 0).todense())
       [[1 3]
        [7 9]]

   :param array: The sparse tensor to remove rows and columns from.
   :param value: The value to remove.
   :return: The sparse tensor without the rows and columns that contain only ``value``.
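
   The following is a minimal sketch on plain COO coordinates (the ``indices``/``data`` pair a BCOO array
   stores), written with NumPy. The helper ``sparse_remove_empty_rows_and_cols`` is hypothetical, and the
   sketch only covers the case where ``value`` equals the implicit fill value 0. In that case, a row or
   column consists entirely of ``value`` exactly when it holds no stored entry different from ``value``.
   The surviving coordinates are then compacted with the same per-column ranking that
   :func:`shrink_index_array` describes.

   .. code-block:: python

      import numpy as np


      def sparse_remove_empty_rows_and_cols(indices: np.ndarray, data: np.ndarray, value: float = 0):
          """Hypothetical COO sketch, valid when `value` is the implicit fill value 0."""
          # Stored entries equal to `value` behave like implicit zeros, so drop them.
          keep = data != value
          indices, data = indices[keep], data[keep]
          # Rank the surviving row and column coordinates to make them consecutive.
          rows, new_rows = np.unique(indices[:, 0], return_inverse=True)
          cols, new_cols = np.unique(indices[:, 1], return_inverse=True)
          new_indices = np.stack([new_rows, new_cols], axis=1)
          return new_indices, data, (len(rows), len(cols))


      dense = np.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
      new_indices, new_data, shape = sparse_remove_empty_rows_and_cols(np.argwhere(dense != 0),
                                                                       dense[dense != 0])
      result = np.zeros(shape, dtype=new_data.dtype)
      result[new_indices[:, 0], new_indices[:, 1]] = new_data
      print(result)  # [[1 3]
                     #  [7 9]]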


