.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/advanced/broadcasting_your_own_methods.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_advanced_broadcasting_your_own_methods.py:

Broadcast functions across multi-dimensional data
=====================================================

Use the @make_broadcastable decorator to efficiently apply functions across
any data dimension.

.. GENERATED FROM PYTHON SOURCE LINES 9-20

Summary
-------
:func:`@make_broadcastable()` is particularly useful when you need to apply
the same operation to multiple individuals or time points while avoiding the
need to write complex loops.

The example walks through a practical case study of detecting when animals
enter a specific region of interest, showing how to convert a simple
point-in-rectangle check into a function that works on a data array with many
time-varying point trajectories.

.. GENERATED FROM PYTHON SOURCE LINES 22-29

Imports
-------
We will need ``numpy`` and ``xarray`` to make our custom data for this
example, and ``matplotlib`` to show what it contains. We will be using the
:mod:`movement.utils.broadcasting` module to turn our one-dimensional
functions into functions that work across entire ``DataArray`` objects.

.. GENERATED FROM PYTHON SOURCE LINES 31-43

.. code-block:: Python

    # For interactive plots: install ipympl with `pip install ipympl` and uncomment
    # the following lines in your notebook
    # %matplotlib widget
    import matplotlib.pyplot as plt
    import numpy as np
    import xarray as xr

    from movement import sample_data
    from movement.plots import plot_centroid_trajectory
    from movement.utils.broadcasting import make_broadcastable

.. GENERATED FROM PYTHON SOURCE LINES 44-49

Load Sample Dataset
-------------------
First, we load the ``SLEAP_three-mice_Aeon_proofread`` example dataset.
For the rest of this example we'll only need the ``position`` data array, so
we store it in a separate variable.

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: Python

    ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5")
    positions: xr.DataArray = ds.position

.. GENERATED FROM PYTHON SOURCE LINES 53-57

The individuals in this dataset follow very similar, arc-like trajectories.
To help emphasise what we are doing in this example, we will offset the paths
of two of the individuals by a small amount so that the trajectories are more
distinct.

.. GENERATED FROM PYTHON SOURCE LINES 57-61

.. code-block:: Python

    positions.loc[:, "y", :, "AEON3B_TP1"] -= 100.0
    positions.loc[:, "y", :, "AEON3B_TP2"] += 100.0

.. GENERATED FROM PYTHON SOURCE LINES 62-85

.. code-block:: Python

    fig, ax = plt.subplots(1, 1)
    for mouse_name, col in zip(
        positions.individuals.values, ["r", "g", "b"], strict=False
    ):
        plot_centroid_trajectory(
            positions,
            individual=mouse_name,
            keypoints="centroid",
            ax=ax,
            linestyle="-",
            marker=".",
            s=2,
            linewidth=0.5,
            c=col,
            label=mouse_name,
        )
    ax.invert_yaxis()
    ax.set_title("Trajectories")
    ax.set_xlabel("x (pixels)")
    ax.set_ylabel("y (pixels)")
    ax.legend()

.. image-sg:: /examples/advanced/images/sphx_glr_broadcasting_your_own_methods_001.png
   :alt: Trajectories
   :srcset: /examples/advanced/images/sphx_glr_broadcasting_your_own_methods_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 86-97

Motivation
----------
Suppose that, during our experiment, we have a region of the enclosure that
has a slightly wet floor, making it slippery. The individuals must cross this
region in order to reach some kind of reward on the other side of the
enclosure.
We know that the "slippery region" of our enclosure is approximately
rectangular in shape, and has its opposite corners at (400, 0) and
(600, 2000), where the coordinates are given in pixels.
We could then write a function that determines if a given (x, y) position was
inside this "slippery region".

.. GENERATED FROM PYTHON SOURCE LINES 97-118

.. code-block:: Python

    def in_slippery_region(xy_position) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        # The slippery region is a rectangle with the following bounds
        x_min, y_min = 400.0, 0.0
        x_max, y_max = 600.0, 2000.0
        # Both bounds are inclusive, in x and in y
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # We can just check our function with a few sample points
    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(f"{point} is in slippery region: {in_slippery_region(point)}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 119-124

Determine if each position was slippery
---------------------------------------
Given our data, we could extract whether each position (for each time-point,
and each individual) was inside the slippery region by looping over the
values.

.. GENERATED FROM PYTHON SOURCE LINES 124-166

.. code-block:: Python

    data_shape = positions.shape
    in_slippery = np.zeros(
        shape=(
            len(positions["time"]),
            len(positions["keypoints"]),
            len(positions["individuals"]),
        ),
        dtype=bool,
    )  # We would save one result per time-point, per keypoint, per individual

    # Feel free to uncomment the print statements
    # (line-by-line progress through the loop),
    # if you are running this code on your own machine.
    for time_index, time in enumerate(positions["time"].values):
        # print(f"At time {time}:")
        for keypoint_index, keypoint in enumerate(positions["keypoints"].values):
            # print(f"\tAt keypoint {keypoint}")
            for individual_index, individual in enumerate(
                positions["individuals"].values
            ):
                xy_point = positions.sel(
                    time=time,
                    keypoints=keypoint,
                    individuals=individual,
                )
                was_in_slippery = in_slippery_region(xy_point)
                was_in_slippery_text = (
                    "was in slippery region"
                    if was_in_slippery
                    else "was not in slippery region"
                )
                # print(
                #     "\t\tIndividual "
                #     f"{positions['individuals'].values[individual_index]} "
                #     f"{was_in_slippery_text}"
                # )
                # Save our result to our large array
                in_slippery[time_index, keypoint_index, individual_index] = (
                    was_in_slippery
                )

.. GENERATED FROM PYTHON SOURCE LINES 167-169

We could then build a new ``DataArray`` to store our results, so that we can
access the results in the same way that we did our original data.

.. GENERATED FROM PYTHON SOURCE LINES 169-185

.. code-block:: Python

    was_in_slippery_region = xr.DataArray(
        in_slippery,
        dims=["time", "keypoints", "individuals"],
        coords={
            "time": positions["time"],
            "keypoints": positions["keypoints"],
            "individuals": positions["individuals"],
        },
    )

    print(
        "Boolean DataArray indicating if at a given time, "
        "a given individual was inside the slippery region:"
    )
    was_in_slippery_region

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Boolean DataArray indicating if at a given time, a given individual was inside the slippery region:

.. raw:: html

<xarray.DataArray (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


.. GENERATED FROM PYTHON SOURCE LINES 186-188

We could get the first and last time that an individual was inside the
slippery region now, by examining this ``DataArray``.

.. GENERATED FROM PYTHON SOURCE LINES 188-199

.. code-block:: Python

    i_id = "AEON3B_NTP"
    individual_0_centroid = was_in_slippery_region.sel(
        individuals=i_id, keypoints="centroid"
    )
    first_entry = individual_0_centroid["time"][individual_0_centroid].values[0]
    last_exit = individual_0_centroid["time"][individual_0_centroid].values[-1]
    print(
        f"{i_id} first entered the slippery region at "
        f"{first_entry} and last exited at {last_exit}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    AEON3B_NTP first entered the slippery region at 2.1 and last exited at 8.64

.. GENERATED FROM PYTHON SOURCE LINES 200-213

Data Generalisation Issues
--------------------------
The shape of the resulting ``DataArray`` is the same as our original
``DataArray``, but without the ``"space"`` dimension.
Indeed, we have essentially collapsed the ``"space"`` dimension, since our
``in_slippery_region`` function takes in a 1D data slice (the x, y positions
of a single individual's centroid at a given point in time) and returns a
scalar value (True/False).
However, the fact that we have to construct a new ``DataArray`` after running
our function over all space slices in our ``DataArray`` is not scalable - our
``for`` loop approach relied on knowing how many dimensions our data had (and
the size of those dimensions). We don't have a guarantee that the next
``DataArray`` that comes in will have the same structure.

.. GENERATED FROM PYTHON SOURCE LINES 215-225

Making our Function Broadcastable
---------------------------------
To combat this problem, we can make the observation that given any
``DataArray``, we always want to broadcast our ``in_slippery_region`` function
along the ``"space"`` dimension.
By "broadcast", we mean that we always want to run our function for each
1D-slice in the ``"space"`` dimension, since these are the (x, y) coordinates.
As such, we can decorate our function with :func:`@make_broadcastable()`:

.. GENERATED FROM PYTHON SOURCE LINES 225-232

.. code-block:: Python

    @make_broadcastable()
    def in_slippery_region_broadcastable(xy_position) -> bool:
        return in_slippery_region(xy_position=xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 233-238

Note that when writing your own methods, there is no need to have both
``in_slippery_region`` and ``in_slippery_region_broadcastable``; simply apply
:func:`@make_broadcastable()` to ``in_slippery_region`` directly.
We've made two separate functions here to illustrate what's going on.

.. GENERATED FROM PYTHON SOURCE LINES 240-242

``in_slippery_region_broadcastable`` is usable in exactly the same ways as
``in_slippery_region`` was:

.. GENERATED FROM PYTHON SOURCE LINES 242-250

.. code-block:: Python

    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(
            f"{point} is in slippery region: "
            f"{in_slippery_region_broadcastable(point)}"
        )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 251-256

However, ``in_slippery_region_broadcastable`` also takes a ``DataArray`` as
the first (``xy_position``) argument, and an extra keyword argument
``broadcast_dimension``. These arguments let us broadcast across the given
dimension of the input ``DataArray``, treating each 1D-slice as a separate
input to ``in_slippery_region``.

.. GENERATED FROM PYTHON SOURCE LINES 256-265

.. code-block:: Python

    in_slippery_region_broadcasting = in_slippery_region_broadcastable(
        positions,  # Now a DataArray input
        broadcast_dimension="space",
    )

    print("DataArray output using broadcasting: ")
    in_slippery_region_broadcasting

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    DataArray output using broadcasting:

.. raw:: html

<xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


.. GENERATED FROM PYTHON SOURCE LINES 266-271

Calling ``in_slippery_region_broadcastable`` in this way gives us a
``DataArray`` output - and one that retains any information that was in our
original ``DataArray`` to boot! The result is exactly the same as what we got
from using our ``for`` loop, and then adding the extra information to the
result.

.. GENERATED FROM PYTHON SOURCE LINES 271-277

.. code-block:: Python

    # Throws an AssertionError if the two inputs are not the same
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 278-281

But importantly, ``in_slippery_region_broadcastable`` also works on
``DataArrays`` with different dimensions.
For example, we could have pre-selected one of our individuals beforehand.

.. GENERATED FROM PYTHON SOURCE LINES 281-295

.. code-block:: Python

    i_id = "AEON3B_NTP"
    individual_0 = positions.sel(individuals=i_id)

    individual_0_in_slippery_region = in_slippery_region_broadcastable(
        individual_0,
        broadcast_dimension="space",
    )

    # Selecting one individual leaves (time, space, keypoints): a 3D input
    print(
        "We get a 2D DataArray output from our 3D input, "
        "again with the 'space' dimension that we broadcast along collapsed:"
    )
    individual_0_in_slippery_region

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    We get a 2D DataArray output from our 3D input, again with the 'space' dimension that we broadcast along collapsed:

.. raw:: html

<xarray.DataArray 'position' (time: 601, keypoints: 1)> Size: 601B
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
        individuals  <U10 40B 'AEON3B_NTP'


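Conceptually, broadcasting along ``"space"`` means applying the function to
every 1D slice along that dimension, which collapses it. The sketch below
mimics this idea with plain NumPy and ``np.apply_along_axis`` on a small
synthetic array. It is only an illustration, not a description of how
``movement`` implements :func:`@make_broadcastable()`; the names
``in_region_sketch`` and ``SPACE_AXIS`` are made up for this example.

.. code-block:: Python

    import numpy as np

    SPACE_AXIS = 1  # hypothetical: position of the "space" axis in the arrays below


    def in_region_sketch(xy_position) -> bool:
        """Point-in-rectangle check with the same bounds as the example above."""
        x_min, y_min = 400.0, 0.0
        x_max, y_max = 600.0, 2000.0
        return bool(
            x_min <= xy_position[0] <= x_max and y_min <= xy_position[1] <= y_max
        )


    # Synthetic data with shape (time=3, space=2, individuals=2)
    synthetic = np.array(
        [
            [[100.0, 450.0], [50.0, 700.0]],
            [[500.0, 601.0], [1500.0, 500.0]],
            [[600.0, 400.0], [2000.0, 0.0]],
        ]
    )

    # Apply the check to every 1D (x, y) slice; the space axis is collapsed
    full_result = np.apply_along_axis(in_region_sketch, SPACE_AXIS, synthetic)
    print(full_result.shape)

    # Pre-selecting one individual drops a dimension, yet the same call works
    one_individual = synthetic[:, :, 0]  # shape (time=3, space=2)
    single_result = np.apply_along_axis(
        in_region_sketch, SPACE_AXIS, one_individual
    )
    print(single_result.shape)

In this sketch the caller still has to know the integer position of the space
axis, and the output is a bare array with no coordinates attached. The
decorated function shown earlier instead takes the dimension by name and
returns a ``DataArray``.
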
.. GENERATED FROM PYTHON SOURCE LINES 296-308

Additional Function Arguments
-----------------------------
So far our ``in_slippery_region`` function only takes a single argument, the
``xy_position`` itself. However, in follow-up experiments we might move the
slippery region in the enclosure, so we adapt our existing function to make
it more general.
It will now allow someone to input a custom rectangular region, by specifying
the minimum and maximum ``(x, y)`` coordinates of the rectangle, rather than
relying on fixed values inside the function.
The default region will be the rectangle from our first experiment, and we
still want to be able to broadcast this function.
And so we write a more general function, as below.

.. GENERATED FROM PYTHON SOURCE LINES 308-331

.. code-block:: Python

    @make_broadcastable()
    def in_slippery_region_general(
        xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    ) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        x_min, y_min = xy_min
        x_max, y_max = xy_max
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # (0.5, 0.5) is in the unit square whose bottom left corner is at the origin
    print(in_slippery_region_general((0.5, 0.5), (0.0, 0.0), (1.0, 1.0)))
    # But (0.5,0.5) is not in a unit square whose bottom left corner is at (1,1)
    print(in_slippery_region_general((0.5, 0.5), (1.0, 1.0), (2.0, 2.0)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True
    False

.. GENERATED FROM PYTHON SOURCE LINES 332-336

We will find that :func:`@make_broadcastable()` retains the additional
arguments to the function we define. However, the ``xy_position`` argument
has to be the first argument in the function's ``def`` statement.

.. GENERATED FROM PYTHON SOURCE LINES 336-345

.. code-block:: Python

    # Default arguments should give us the same results as before
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_general(positions)
    )
    # But we can also provide the optional arguments in the same way as with the
    # un-decorated function.
    in_slippery_region_general(positions, xy_min=(100, 0), xy_max=(400, 1000))

.. raw:: html

<xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False True False False True False ... True True False True True False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


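To see why extra arguments can simply be passed through, consider this
simplified stand-in for a broadcasting decorator. It is a hypothetical,
NumPy-only sketch: ``broadcast_over_axis`` and ``in_rectangle_sketch`` are
names invented here, and the real :func:`@make_broadcastable()` works on
``DataArray`` objects and may be implemented quite differently. The wrapper
treats the first argument as the data to slice and forwards everything else
unchanged.

.. code-block:: Python

    import functools

    import numpy as np


    def broadcast_over_axis(func):
        """Hypothetical decorator: apply ``func`` to each 1D slice of the data."""

        @functools.wraps(func)
        def wrapper(data, *args, axis=-1, **kwargs):
            # Only the first argument is sliced; the rest is forwarded as given
            return np.apply_along_axis(
                func, axis, np.asarray(data), *args, **kwargs
            )

        return wrapper


    @broadcast_over_axis
    def in_rectangle_sketch(
        xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    ):
        x_min, y_min = xy_min
        x_max, y_max = xy_max
        return bool(
            x_min <= xy_position[0] <= x_max and y_min <= xy_position[1] <= y_max
        )


    # Three (x, y) points, one per row; the last axis holds the coordinates
    points = np.array([[0.5, 0.5], [1.5, 1.5], [500.0, 1000.0]])

    default_region = in_rectangle_sketch(points)
    unit_square = in_rectangle_sketch(points, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0))
    shifted_square = in_rectangle_sketch(points, (1.0, 1.0), (2.0, 2.0))

A wrapper like this needs a fixed place to find the data it should slice,
which is one way to understand why the sliced argument has to come first.
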
.. GENERATED FROM PYTHON SOURCE LINES 346-353

Only Broadcast Along Select Dimensions
--------------------------------------
:func:`@make_broadcastable()` has some flexibility with its input arguments,
to help you avoid unintentional behaviour. You may have noticed, for example,
that there is nothing stopping someone who wants to use your analysis code
from trying to broadcast along the wrong dimension.

.. GENERATED FROM PYTHON SOURCE LINES 353-361

.. code-block:: Python

    silly_broadcast = in_slippery_region_broadcastable(
        positions, broadcast_dimension="time"
    )

    print("The output has collapsed the time dimension:")
    silly_broadcast

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The output has collapsed the time dimension:

.. raw:: html

<xarray.DataArray 'position' (space: 2, keypoints: 1, individuals: 3)> Size: 6B
    False False False False False False
    Coordinates:
      * space        (space) <U1 8B 'x' 'y'
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


.. GENERATED FROM PYTHON SOURCE LINES 362-373

There is no error thrown because functionally, this is a valid operation.
The time slices of our data were 1D, so we can run ``in_slippery_region`` on
them. But each slice isn't a position, it's an array of one spatial
coordinate (e.g. x) for each keypoint, each individual, at every time! So
from an analysis standpoint, doing this doesn't make sense and isn't how we
intend our function to be used.

We can pass the ``only_broadcastable_along`` keyword argument to
:func:`@make_broadcastable()` to prevent these kinds of mistakes, and make
our intentions clearer.

.. GENERATED FROM PYTHON SOURCE LINES 373-380

.. code-block:: Python

    @make_broadcastable(only_broadcastable_along="space")
    def in_slippery_region_space_only(xy_position):
        return in_slippery_region(xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 381-383

Now, ``in_slippery_region_space_only`` no longer takes the
``broadcast_dimension`` argument.

.. GENERATED FROM PYTHON SOURCE LINES 383-392

.. code-block:: Python

    try:
        in_slippery_region_space_only(
            positions,
            broadcast_dimension="time",
        )
    except TypeError as e:
        print(f"Got a TypeError when trying to run, here's the message:\n{e}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Got a TypeError when trying to run, here's the message:
    __main__.in_slippery_region_space_only() got multiple values for keyword argument 'broadcast_dimension'

.. GENERATED FROM PYTHON SOURCE LINES 393-400

The error we get seems to be telling us that we've tried to set the value of
``broadcast_dimension`` twice. Specifying
``only_broadcastable_along = "space"`` forces ``broadcast_dimension`` to be
set to ``"space"``, so trying to set it again (even to the same value)
results in an error.
However, ``in_slippery_region_space_only`` knows to only use the ``"space"``
dimension of the input by default.

.. GENERATED FROM PYTHON SOURCE LINES 400-407

.. code-block:: Python

    was_in_view_space_only = in_slippery_region_space_only(positions)

    xr.testing.assert_equal(
        in_slippery_region_broadcasting, was_in_view_space_only
    )

.. GENERATED FROM PYTHON SOURCE LINES 408-415

It is worth noting that there is a "helper" decorator,
:func:`@space_broadcastable()`, that essentially does the same thing as
:func:`@make_broadcastable(only_broadcastable_along="space")`.
You can use this decorator for your own convenience.

.. GENERATED FROM PYTHON SOURCE LINES 417-422

Extending to Class Methods
--------------------------
:func:`@make_broadcastable()` can also be applied to class methods, though it
needs to be told that you are doing so via the ``is_classmethod`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 422-452

.. code-block:: Python

    class Rectangle:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @make_broadcastable(is_classmethod=True, only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region = Rectangle(xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0))
    was_in_region_clsmethod = slippery_region.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 453-459

:func:`@broadcastable_method()` is provided as a helpful alias for
:func:`@make_broadcastable(is_classmethod=True)`, and otherwise works in the
same way (and accepts the same parameters).

.. GENERATED FROM PYTHON SOURCE LINES 459-493

.. code-block:: Python

    # The alias lives alongside make_broadcastable
    from movement.utils.broadcasting import broadcastable_method


    class RectangleAlternative:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @broadcastable_method(only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region_alt = RectangleAlternative(
        xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    )
    # Call the method on the new instance, not on the earlier Rectangle
    was_in_region_clsmethod_alt = slippery_region_alt.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod_alt, in_slippery_region_broadcasting
    )
    xr.testing.assert_equal(was_in_region_clsmethod_alt, was_in_region_clsmethod)

.. GENERATED FROM PYTHON SOURCE LINES 494-499

In fact, if you look at the :mod:`movement.roi` module, and in particular the
classes inside it, you'll notice that we use :func:`@broadcastable_method()`
ourselves in some of these methods!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.113 seconds)