EIANN.utils.data_utils#
Functions#
- get_project_root – Find the root directory of the project containing the 'EIANN' folder.
- nested_convert_scalars – Crawl a nested dictionary and recursively convert all NumPy scalar types to native Python types in the nested structure.
- write_to_yaml – Write a dictionary to a YAML file.
- read_from_yaml – Read a dictionary from a YAML file.
- export_metrics_data – Export metrics data to an HDF5 file under a specific model group.
- import_metrics_data – Import metrics data from an HDF5 file into a nested dictionary.
- hdf5_to_dict – Convert the contents of an HDF5 file into a nested dictionary.
- convert_hdf5_group_to_dict – Helper function to recursively convert an HDF5 group to a nested dictionary.
- dict_to_hdf5 – Save a nested dictionary to an HDF5 file.
- convert_dict_to_hdf5_group – Recursively write a nested dictionary to an HDF5 group.
- save_plot_data – Save plot data for a specific network and seed into an HDF5 file.
- load_plot_data – Load plot data for a specific network and seed from an HDF5 file.
- delete_plot_data – Delete a specific variable from an HDF5 file.
- get_MNIST_dataloaders – Load MNIST dataset and return custom dataloaders that include sample indices.
- get_FashionMNIST_dataloaders – Load FashionMNIST dataset and return custom dataloaders that include sample indices.
- get_cifar10_dataloaders
- generate_spiral_data – Generate a synthetic spiral dataset.
- get_spiral_dataloaders – Generate spiral dataset and return dataloaders.
- generate_inhomogeneous_poisson_spikes – Generate spike times from an inhomogeneous Poisson process using the thinning method.
- n_choose_k – Calculate the number of ways to choose k items from n, using binomial coefficients.
- n_hot_patterns – Generate all possible binary n-hot patterns of a given length.
Module Contents#
- get_project_root()#
Find the root directory of the project containing the ‘EIANN’ folder.
- Returns:
Absolute path to the project root directory.
- Return type:
str
- Raises:
FileNotFoundError – If the ‘EIANN’ directory cannot be found.
- nested_convert_scalars(data)#
Crawl a nested dictionary and recursively convert all NumPy scalar types to native Python types in the nested structure.
- Parameters:
data (any) – A potentially nested structure (dict, list, tuple) containing scalars.
- Returns:
The same structure with NumPy scalars converted to native Python types.
- Return type:
any
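Example (a minimal sketch of this kind of recursive conversion, written here for illustration; it is not the EIANN source):

```python
import numpy as np

def convert_scalars(data):
    # Recurse through dicts, lists, and tuples, converting any
    # NumPy scalar (np.generic) to a native Python type via .item()
    if isinstance(data, dict):
        return {k: convert_scalars(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(convert_scalars(v) for v in data)
    if isinstance(data, np.generic):
        return data.item()
    return data

nested = {'a': np.float64(1.5), 'b': [np.int32(2), 3]}
clean = convert_scalars(nested)
```

Converting scalars this way is what makes a metrics dictionary safe to serialize with plain YAML, which cannot represent NumPy types.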
- write_to_yaml(file_path, data, convert_scalars=True)#
Write a dictionary to a YAML file.
- Parameters:
file_path (str) – Path to the output YAML file. Should end with ‘.yaml’.
data (dict) – Dictionary to write to the file.
convert_scalars (bool, optional) – Whether to convert NumPy scalar types to native Python types before writing. Default is True.
- read_from_yaml(file_path, Loader=None)#
Read a dictionary from a YAML file.
- Parameters:
file_path (str) – Path to the YAML file to read.
Loader (yaml.Loader or None, optional) – YAML loader to use. Defaults to yaml.FullLoader.
- Returns:
Dictionary parsed from the YAML file.
- Return type:
dict
- Raises:
Exception – If the specified file does not exist.
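Example (a plain-PyYAML round trip equivalent to what these two helpers do; the file path and config keys are hypothetical):

```python
import os
import tempfile

import yaml

# Write a config dict to YAML, then read it back with yaml.FullLoader,
# the default loader named in the read_from_yaml docstring
config = {'learning_rate': 0.01, 'hidden_layers': [128, 64]}
path = os.path.join(tempfile.mkdtemp(), 'config.yaml')
with open(path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)
with open(path, 'r') as f:
    loaded = yaml.load(f, Loader=yaml.FullLoader)
```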
- export_metrics_data(metrics_dict, model_name, path)#
Export metrics data to an HDF5 file under a specific model group.
- Parameters:
metrics_dict (dict) – Dictionary of metrics to export.
model_name (str) – Name of the model used as the top-level group in the HDF5 file.
path (str) – Path to the HDF5 file. If missing ‘.hdf5’, it will be appended.
- import_metrics_data(filename)#
Import metrics data from an HDF5 file into a nested dictionary.
- Parameters:
filename (str) – Path to the HDF5 file.
- Returns:
Nested dictionary of metrics by model and metric name.
- Return type:
dict
- hdf5_to_dict(file_path)#
Convert the contents of an HDF5 file into a nested dictionary.
- Parameters:
file_path (str) – Path to the HDF5 file.
- Returns:
Nested Python dictionary representing the HDF5 file structure.
- Return type:
dict
- convert_hdf5_group_to_dict(group)#
Helper function to recursively convert an HDF5 group to a nested dictionary.
- Parameters:
group (h5py.Group) – The HDF5 group to convert.
- Returns:
Dictionary representing the structure and datasets within the group.
- Return type:
dict
- dict_to_hdf5(data_dict, file_path)#
Save a nested dictionary to an HDF5 file.
- Parameters:
data_dict (dict) – Dictionary to save to the file.
file_path (str) – Destination path for the HDF5 file.
- convert_dict_to_hdf5_group(data_dict, group)#
Recursively write a nested dictionary to an HDF5 group.
- Parameters:
data_dict (dict) – Dictionary to write to the HDF5 group.
group (h5py.Group) – Target HDF5 group for storing the dictionary data.
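Example (a self-contained h5py sketch of the recursive dict-to-group mapping these helpers perform; the function names and file path here are illustrative, not the EIANN implementations):

```python
import os
import tempfile

import h5py
import numpy as np

def write_group(d, group):
    # Mirror a nested dict as HDF5 groups (for dict values)
    # and datasets (for array/scalar values)
    for key, value in d.items():
        if isinstance(value, dict):
            write_group(value, group.create_group(key))
        else:
            group[key] = value

def read_group(group):
    # Inverse mapping: groups become dicts, datasets become arrays
    return {k: (read_group(v) if isinstance(v, h5py.Group) else v[()])
            for k, v in group.items()}

data = {'model_A': {'accuracy': np.array([0.9, 0.95])}}
path = os.path.join(tempfile.mkdtemp(), 'metrics.hdf5')
with h5py.File(path, 'w') as f:
    write_group(data, f)
with h5py.File(path, 'r') as f:
    restored = read_group(f)
```

The top-level group plays the role of the model name in export_metrics_data, so several models can share one file.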
- save_plot_data(network_name, seed, data_key, data, file_path=None, overwrite=False)#
Save plot data for a specific network and seed into an HDF5 file.
- Parameters:
network_name (str) – Name of the network.
seed (int) – Seed identifier for the data.
data_key (str) – Key under which to store the data.
data (array-like or dict) – Data to be saved.
file_path (str, optional) – Path to the HDF5 file. If None, a default path is used.
overwrite (bool, optional) – Whether to overwrite existing data at the specified key.
- load_plot_data(network_name, seed, data_key, file_path=None)#
Load plot data for a specific network and seed from an HDF5 file.
- Parameters:
network_name (str) – Name of the network.
seed (int) – Seed identifier for the data.
data_key (str) – Key under which the data is stored.
file_path (str, optional) – Path to the HDF5 file. If None, a default path is used.
- Returns:
The loaded data if present, otherwise None.
- Return type:
any or None
- delete_plot_data(variable_name, file_name, file_path_prefix='../data/')#
Delete a specific variable from an HDF5 file.
- Parameters:
variable_name (str) – Name of the variable to delete.
file_name (str) – Name of the HDF5 file.
file_path_prefix (str, optional) – Path prefix for the file location. Default is ‘../data/’.
- get_MNIST_dataloaders(sub_dataloader_size=None, batch_size=1, data_dir=None)#
Load MNIST dataset and return custom dataloaders that include sample indices.
- Parameters:
sub_dataloader_size (int, optional) – If set, creates a separate dataloader with this many samples.
batch_size (int, optional) – Batch size for the sub-dataloader.
data_dir (str, optional) – Path to the dataset directory. If None, a default path is used.
- Returns:
Tuple of DataLoaders: (train, [train_sub], val, test, generator).
- Return type:
tuple
- get_FashionMNIST_dataloaders(sub_dataloader_size=None, batch_size=1, data_dir=None)#
Load FashionMNIST dataset and return custom dataloaders that include sample indices.
- Parameters:
sub_dataloader_size (int, optional) – If set, creates a separate dataloader with this many samples.
batch_size (int, optional) – Batch size for the sub-dataloader.
data_dir (str, optional) – Path to the dataset directory. If None, a default path is used.
- Returns:
Tuple of DataLoaders: (train, [train_sub], val, test, generator).
- Return type:
tuple
- get_cifar10_dataloaders(sub_dataloader_size=None, batch_size=1, data_dir=None)#
Load CIFAR10 dataset and return custom dataloaders that include sample indices.
- Parameters:
sub_dataloader_size (int, optional) – If set, creates a separate dataloader with this many samples.
batch_size (int, optional) – Batch size for the sub-dataloader.
data_dir (str, optional) – Path to the dataset directory. If None, a default path is used.
- generate_spiral_data(arm_size=500, K=4, sigma=0.16, seed=0, offset=0.0)#
Generate a synthetic spiral dataset.
- Parameters:
arm_size (int, optional) – Number of points per spiral arm. Default is 500.
K (int, optional) – Number of spiral arms. Default is 4.
sigma (float, optional) – Noise level. Default is 0.16.
seed (int, optional) – Random seed. Default is 0.
offset (float, optional) – Offset to apply to spiral coordinates. Default is 0.
- Returns:
Each element is (index, data_tensor, one_hot_target).
- Return type:
list of tuples
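Example (a NumPy-only sketch of how K noisy spiral arms can be generated; this is an illustrative re-implementation, not the EIANN source, which also wraps each point as an (index, data_tensor, one_hot_target) tuple):

```python
import numpy as np

def spiral_points(arm_size=500, K=4, sigma=0.16, seed=0):
    # Generate K spiral arms in 2D; points on arm k get label k
    rng = np.random.default_rng(seed)
    X, y = [], []
    for k in range(K):
        t = np.linspace(0.0, 1.0, arm_size)           # radius parameter
        theta = 4 * np.pi * t + 2 * np.pi * k / K     # arm-specific rotation
        arm = np.stack([t * np.cos(theta), t * np.sin(theta)], axis=1)
        X.append(arm + sigma * rng.normal(size=arm.shape))  # additive noise
        y.append(np.full(arm_size, k))
    return np.concatenate(X), np.concatenate(y)

X, y = spiral_points(arm_size=100, K=4)
```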
- get_spiral_dataloaders(batch_size=1, points_per_spiral_arm=2000, seed=0)#
Generate spiral dataset and return dataloaders.
- Parameters:
batch_size (int or str, optional) – Batch size for training. Use ‘all’ or ‘full_dataset’ to load all data in one batch.
points_per_spiral_arm (int, optional) – Number of points per spiral arm. Default is 2000.
seed (int, optional) – Random seed. Default is 0.
- Returns:
Tuple of DataLoaders: (train, val, test, generator).
- Return type:
tuple
- generate_inhomogeneous_poisson_spikes(rate, refractory_period=3)#
Generate spike times from an inhomogeneous Poisson process using the thinning method.
Example usage:
```python
rate = 300 * np.ones(1000)  # Example rate in Hz
spike_times = generate_inhomogeneous_poisson_spikes(rate, refractory_period=2)
```
- Parameters:
rate (numpy.ndarray) – Time series of firing rates in Hz, sampled at 1 ms intervals. Each element represents the instantaneous firing rate at that time point.
refractory_period (float, optional) – Minimum interval between spikes in milliseconds. Default is 3.
- Returns:
Spike times in seconds.
- Return type:
list
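The thinning method named in the docstring can be sketched as follows (a hypothetical re-implementation for illustration, not the EIANN code): candidate spikes are drawn from a homogeneous Poisson process at the maximum rate, and each candidate is kept with probability rate(t)/rate_max, subject to the refractory period.

```python
import numpy as np

def poisson_thinning_spikes(rate, refractory_period=3.0, seed=0):
    # rate: firing rate in Hz per 1 ms bin; returns spike times in seconds
    rng = np.random.default_rng(seed)
    rate_max = float(np.max(rate))
    spikes, t, last = [], 0.0, -np.inf
    while t < len(rate):                          # time t is in ms
        t += rng.exponential(1000.0 / rate_max)   # next candidate spike (ms)
        if t >= len(rate):
            break
        # Accept with probability rate(t)/rate_max, enforcing refractoriness
        if (rng.random() < rate[int(t)] / rate_max
                and t - last >= refractory_period):
            spikes.append(t / 1000.0)             # store in seconds
            last = t
    return spikes

rate = 300 * np.ones(1000)   # 300 Hz for 1 s
spikes = poisson_thinning_spikes(rate, refractory_period=2)
```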
- n_choose_k(n, k)#
Calculate the number of ways to choose k items from n, i.e., the binomial coefficient.
- Parameters:
n (int) – Number of items to choose from.
k (int) – Number of items chosen.
- Returns:
The binomial coefficient C(n, k).
- Return type:
int
- n_hot_patterns(n, length)#
Generate all possible binary n-hot patterns of a given length.
- Parameters:
n (int) – Number of bits set to 1.
length (int) – Size of the pattern (number of bits).
- Returns:
All binary patterns of the given length with exactly n bits set to 1.
- Return type:
torch.Tensor
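Example (a pure-Python sketch of n-hot pattern enumeration using itertools; the EIANN version returns the patterns stacked in a torch tensor instead of a list of lists):

```python
from itertools import combinations
from math import comb

def n_hot_patterns_sketch(n, length):
    # Enumerate every binary pattern of the given length with exactly
    # n ones by choosing the positions of the ones
    patterns = []
    for ones in combinations(range(length), n):
        patterns.append([1 if i in ones else 0 for i in range(length)])
    return patterns

pats = n_hot_patterns_sketch(2, 4)
```

The number of patterns is exactly n_choose_k(length, n), which is why the two functions appear together.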