virne.solver.learning.utils

Functions

apply_mask_to_logit(logit[, mask])

Apply a mask to a given logits tensor.
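Masking logits is usually done by pushing invalid entries to a very large negative value, so a following softmax gives them (numerically) zero probability. The sketch below shows that idea on plain Python lists. It assumes a boolean mask where True means "keep", and an illustrative fill value of -1e8. `apply_mask_to_logit_sketch` is a hypothetical stand-in for the real function, which operates on tensors and may use a different convention or constant.

```python
import math
from typing import List, Optional

# Assumed fill value for masked-out logits; the actual virne helper may use a different constant.
MASK_FILL_VALUE = -1e8


def apply_mask_to_logit_sketch(logit: List[float], mask: Optional[List[bool]] = None) -> List[float]:
    """Hypothetical stand-in for apply_mask_to_logit.

    Entries whose mask value is False (assumed to mean "invalid") are replaced
    with a large negative number. With no mask, the logits are returned unchanged.
    """
    if mask is None:
        return list(logit)
    if len(mask) != len(logit):
        raise ValueError("mask and logit must have the same length")
    return [value if keep else MASK_FILL_VALUE for value, keep in zip(logit, mask)]


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


masked = apply_mask_to_logit_sketch([1.0, 2.0, 3.0], mask=[True, False, True])
probs = softmax(masked)  # the masked entry ends up with probability 0.0
```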

get_available_device()

Get the available device (CPU or GPU).

get_observations_sample(obs_batch, indices)

Get a sample from an input observation batch given the indices.
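The structure of `obs_batch` is not documented here. A minimal sketch, assuming it is either a sequence of per-sample observations or a (possibly nested) dict of such sequences, selects the same indices from every field so that the sampled entries stay aligned. The function name and the field names in the example are hypothetical.

```python
from typing import Any, Dict, List, Sequence, Union

# Assumed layout: a sequence of per-sample observations, or a (possibly nested)
# dict mapping observation names to such sequences.
ObsBatch = Union[Sequence[Any], Dict[str, Any]]


def get_observations_sample_sketch(obs_batch: ObsBatch, indices: Sequence[int]) -> Union[List[Any], Dict[str, Any]]:
    """Hypothetical stand-in for get_observations_sample: pick the samples at `indices`."""
    if isinstance(obs_batch, dict):
        # Recurse into every field with the same indices so the samples stay aligned.
        return {key: get_observations_sample_sketch(value, indices) for key, value in obs_batch.items()}
    return [obs_batch[i] for i in indices]


obs_batch = {
    "node_features": [[0.1], [0.2], [0.3]],
    "action_mask": [[1, 0], [1, 1], [0, 1]],
}
sample = get_observations_sample_sketch(obs_batch, [2, 0])
```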

get_pyg_batch(x_batch, edge_index_batch[, ...])

Convert a batch of node and edge information into PyTorch Geometric format.
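PyTorch Geometric batches several graphs by concatenating their node features into one large disjoint graph: each graph's edge indices are shifted by the number of nodes in the graphs before it, and a `batch` vector records which graph every node came from. The pure-Python sketch below shows only that bookkeeping. It assumes `x_batch` is a list of per-graph node-feature lists and `edge_index_batch` a list of `[sources, targets]` pairs. `batch_graphs_sketch` is a hypothetical name; the real helper presumably builds `torch_geometric` objects instead of lists.

```python
from typing import List, Sequence, Tuple

EdgeIndex = List[List[int]]  # [[source nodes], [target nodes]], i.e. shape [2, num_edges]


def batch_graphs_sketch(
    x_batch: Sequence[List[List[float]]],
    edge_index_batch: Sequence[EdgeIndex],
) -> Tuple[List[List[float]], EdgeIndex, List[int]]:
    """Merge several graphs into one disjoint graph, PyG-style."""
    x_all: List[List[float]] = []
    sources: List[int] = []
    targets: List[int] = []
    batch: List[int] = []  # batch[i] = index of the graph that node i came from
    offset = 0  # number of nodes in all previously merged graphs
    for graph_id, (x, (src, dst)) in enumerate(zip(x_batch, edge_index_batch)):
        x_all.extend(x)
        # Shift local node ids so they point at this graph's rows in x_all.
        sources.extend(s + offset for s in src)
        targets.extend(d + offset for d in dst)
        batch.extend([graph_id] * len(x))
        offset += len(x)
    return x_all, [sources, targets], batch


x, edge_index, batch = batch_graphs_sketch(
    x_batch=[[[0.1], [0.2]], [[0.3], [0.4], [0.5]]],
    edge_index_batch=[[[0], [1]], [[0, 1], [1, 2]]],
)
```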

get_pyg_data(x, edge_index[, edge_attr])

Convert node and edge information into PyTorch Geometric format.

load_pyg_batch_from_network_list(network_list)

Load a PyTorch Geometric batch from a list of networks.

load_pyg_data_from_network(network[, ...])

Load PyTorch Geometric data from a network.

normailize_data(data[, method])

Normalize node or edge data.
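The accepted values of `method` are not listed here. As a sketch, two common choices are min-max scaling to [0, 1] and z-score standardization. The method names `"min_max"` and `"z_score"` below are assumptions, not confirmed virne options, and `normalize_data_sketch` is a hypothetical stand-in that works on a 1-D list rather than node or edge tensors.

```python
import statistics
from typing import List, Sequence


def normalize_data_sketch(data: Sequence[float], method: str = "min_max") -> List[float]:
    """Hypothetical stand-in for normailize_data on a 1-D list of feature values.

    The method names "min_max" and "z_score" are assumptions.
    """
    if method == "min_max":
        low, high = min(data), max(data)
        span = high - low
        if span == 0:  # constant feature: avoid division by zero
            return [0.0 for _ in data]
        return [(v - low) / span for v in data]
    if method == "z_score":
        mean = statistics.fmean(data)
        std = statistics.pstdev(data)  # population standard deviation
        if std == 0:
            return [0.0 for _ in data]
        return [(v - mean) / std for v in data]
    raise ValueError(f"unknown normalization method: {method!r}")
```

Min-max keeps every value in a fixed range, which suits bounded resources such as capacities; z-score is less sensitive to a single extreme value setting the scale.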

update_mean_var_count_from_moments(mean, ...)

Update the running mean, variance, and count using a batch's mean, variance, and count.
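A function with this name in OpenAI Baselines merges two sets of statistics with the standard parallel-variance formula (Chan et al.): the new mean is the count-weighted mean, and the combined sum of squared deviations adds a correction term proportional to the squared gap between the two means. The scalar sketch below assumes virne follows the same formula and that variances are population variances (divided by n). The `_sketch` suffix marks it as a hypothetical illustration, not the library code.

```python
from typing import Tuple


def update_mean_var_count_from_moments_sketch(
    mean: float, var: float, count: float,
    batch_mean: float, batch_var: float, batch_count: float,
) -> Tuple[float, float, float]:
    """Merge running statistics with a batch's statistics (scalar version).

    `var` and `batch_var` are assumed to be population variances.
    """
    delta = batch_mean - mean
    total_count = count + batch_count
    new_mean = mean + delta * batch_count / total_count
    # Sum of squared deviations of each part, plus a correction for the gap between the means.
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / total_count
    new_var = m2 / total_count
    return new_mean, new_var, total_count
```

For example, merging [1, 2, 3] (mean 2, variance 2/3, count 3) with [4, 5] (mean 4.5, variance 0.25, count 2) gives mean 3, variance 2, and count 5, the same as computing them over [1, 2, 3, 4, 5] directly.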

Classes

RunningMeanStd([epsilon, shape])

Calculate the running mean and standard deviation of a stream of data.
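A running tracker of this kind typically starts with mean 0, variance 1, and a tiny `epsilon` count (so the first merge never divides by zero), then folds in each new batch with the parallel-variance update. The sketch below assumes those initial values and tracks a single scalar; the real class also takes a `shape` argument and presumably works on arrays. `RunningMeanStdSketch` is a hypothetical name.

```python
import math
from typing import Sequence


class RunningMeanStdSketch:
    """Scalar running mean/variance tracker (hypothetical simplification of RunningMeanStd)."""

    def __init__(self, epsilon: float = 1e-4) -> None:
        self.mean = 0.0  # assumed initial mean
        self.var = 1.0  # assumed initial variance
        self.count = epsilon  # small pseudo-count avoids division by zero

    @property
    def std(self) -> float:
        return math.sqrt(self.var)

    def update(self, batch: Sequence[float]) -> None:
        """Fold a non-empty batch of observations into the running statistics."""
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n  # population variance
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean += delta * n / total
        m2 = self.var * self.count + batch_var * n + delta ** 2 * self.count * n / total
        self.var = m2 / total
        self.count = total
```

A common use is normalizing rewards or observations in reinforcement learning, e.g. `(x - rms.mean) / (rms.std + 1e-8)`; because of the `epsilon` pseudo-count, the estimates are very slightly biased toward the initial values until enough data has been seen.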