yamle.methods.ensemble module

class yamle.methods.ensemble.EnsembleMethod(*args, **kwargs)

Bases: MemberMethod

This is a Method class for the Ensemble model.

It wraps the original model in the Ensemble model and then uses the base method to train the members one at a time, in cooperation with the EnsembleTrainer class.

Parameters:

num_members (int) – The number of members in the ensemble.
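The member-by-member training described above can be illustrated with a framework-free toy. This is a minimal sketch, not yamle's implementation: each hypothetical "member" is a 1-D linear model fitted with gradient descent from its own random initialisation, the function names are illustrative, and averaging the member outputs is an assumed aggregation rule.

```python
import random
from typing import List, Sequence, Tuple

Member = Tuple[float, float]  # (slope, intercept) of a toy 1-D linear member


def train_member(data: Sequence[Tuple[float, float]], seed: int,
                 epochs: int = 1000, lr: float = 0.1) -> Member:
    """Fit one member with full-batch gradient descent on mean squared error."""
    rng = random.Random(seed)  # each member starts from its own random initialisation
    w, b = rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)
    n = len(data)
    for _ in range(epochs):
        grad_w = sum(2.0 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2.0 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


def train_ensemble(data: Sequence[Tuple[float, float]], num_members: int) -> List[Member]:
    members: List[Member] = []
    for current_member in range(num_members):
        # Members are trained strictly one after another, never jointly.
        members.append(train_member(data, seed=current_member))
    return members


def ensemble_predict(members: Sequence[Member], x: float) -> float:
    # Aggregate by averaging the member outputs (an assumed aggregation rule).
    return sum(w * x + b for w, b in members) / len(members)


data = [(0.0, 1.0), (0.5, 2.0), (1.0, 3.0)]  # samples of y = 2x + 1
members = train_ensemble(data, num_members=3)
```

Training sequentially means only one member's parameters need to be optimised at any moment, which is the property a trainer that fits members one by one relies on.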

increment_current_member()

This method is used to increment the current member index.

Return type:

None
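A minimal sketch of the bookkeeping behind a member counter, assuming the index starts at 0 and must never exceed num_members - 1. The class MemberCounter, the driver train_all and the choice to raise IndexError on overflow are all hypothetical; yamle's own guard behaviour is not described here.

```python
from typing import Callable, List


class MemberCounter:
    """Hypothetical stand-in for the member-index bookkeeping of an ensemble method."""

    def __init__(self, num_members: int) -> None:
        if num_members < 1:
            raise ValueError("num_members must be at least 1")
        self.num_members = num_members
        self.current_member = 0

    def increment_current_member(self) -> None:
        # Move on to the next member; running past the last one is treated as an error.
        if self.current_member + 1 >= self.num_members:
            raise IndexError("every member has already been trained")
        self.current_member += 1


def train_all(counter: MemberCounter, train_one: Callable[[int], str]) -> List[str]:
    """Driver loop in the spirit of a trainer that fits one member per stage."""
    results = [train_one(counter.current_member)]
    while counter.current_member < counter.num_members - 1:
        counter.increment_current_member()
        results.append(train_one(counter.current_member))
    return results
```

For example, `train_all(MemberCounter(3), lambda i: f"member-{i}")` visits members 0, 1 and 2 in order and leaves the counter pointing at the last member.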

class yamle.methods.ensemble.SnapsotEnsembleMethod(*args, **kwargs)

Bases: EnsembleMethod

This is a Method class for the Snapshot Ensemble method.

It wraps the original model in the Ensemble model and uses the base method to train a single network under a cyclic learning rate scheduler. Each time the learning rate reaches its minimum, the current weights are saved as the next ensemble member; the learning rate is then reset to its maximum value and training of the main model continues.

Parameters:

num_members (int) – The number of members in the ensemble.
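The cyclic schedule described above can be sketched as cosine annealing with restarts. This is an illustration under assumptions: the cosine shape, the equal-length cycles and the helper names cyclic_cosine_lr and snapshot_epochs are hypothetical, and the exact schedule yamle configures is not stated in this reference.

```python
import math
from typing import List


def cyclic_cosine_lr(epoch: int, cycle_length: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Cosine-annealed learning rate that restarts at lr_max every cycle_length epochs."""
    position = (epoch % cycle_length) / cycle_length  # fraction of the current cycle, in [0, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * position))


def snapshot_epochs(total_epochs: int, num_members: int) -> List[int]:
    """0-indexed epochs after which a snapshot is taken: the last epoch of each cycle."""
    cycle_length = max(1, total_epochs // num_members)
    ends = [e for e in range(total_epochs) if (e + 1) % cycle_length == 0]
    return ends[:num_members]
```

With 30 epochs and 3 members, each cycle lasts 10 epochs: the rate starts at lr_max on epochs 0, 10 and 20, falls close to lr_min by epochs 9, 19 and 29, and those cycle-end epochs are where a snapshot would be stored.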

get_parameters(recurse=True)

A helper function to get the parameters of a single ensemble member.

In this case, it always returns the parameters of the first member.

Return type:

List[Parameter]

on_train_epoch_end()

This method is called at the end of each training epoch.

In this case, if the learning rate cycle has been completed, the current model’s weights are copied into the next member of the ensemble.

Return type:

None
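A minimal sketch of the end-of-epoch copy step, assuming member 0 is the only network being optimised (consistent with get_parameters always returning the first member) and that completed cycles are written into slots 1, 2, and so on. Plain dictionaries stand in for model state dicts; SnapshotBuffer and its attributes are hypothetical names.

```python
import copy
from typing import Dict, List

State = Dict[str, List[float]]  # parameter name -> values, a stand-in for a state dict


class SnapshotBuffer:
    """Hypothetical sketch: member 0 is trained; finished cycles are copied into later slots."""

    def __init__(self, num_members: int, cycle_length: int, initial: State) -> None:
        self.members: List[State] = [copy.deepcopy(initial) for _ in range(num_members)]
        self.cycle_length = cycle_length
        self.epoch = 0
        self.next_member = 1  # slot 0 holds the model that keeps training

    def trainable_parameters(self) -> State:
        # Mirrors get_parameters(): only the first member is optimised.
        return self.members[0]

    def on_train_epoch_end(self) -> None:
        self.epoch += 1
        cycle_done = self.epoch % self.cycle_length == 0
        if cycle_done and self.next_member < len(self.members):
            # Deep copy so that further training of member 0 does not alias the snapshot.
            self.members[self.next_member] = copy.deepcopy(self.members[0])
            self.next_member += 1
```

The deep copy is the important design choice: a shallow reference would make every stored member track the live weights, collapsing the ensemble into a single model.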

configure_optimizers()

This method is used to configure the optimizers and the learning rate schedulers.

Return type:

Tuple[List[Optimizer], List[_LRScheduler]]

class yamle.methods.ensemble.GradientBoostingEnsembleMethod(num_members, shrinkage=1.0, *args, **kwargs)

Bases: EnsembleMethod

This is a Method class for the Gradient Boosting Ensemble method.

It uses the Ensemble model to wrap around the original model and then uses the base method to train the network.

Parameters:
  • num_members (int) – The number of members in the ensemble.

  • shrinkage (float) – The shrinkage parameter for gradient boosting. Defaults to 1.0.
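To show what a shrinkage factor does, here is generic gradient boosting for squared error, where each new member is fitted to the residuals of the current ensemble and contributes shrinkage times its output. This is a textbook sketch with hypothetical 1-D least-squares members, not yamle's implementation; in particular, applying shrinkage to the first member too is an assumption.

```python
from typing import List, Sequence, Tuple

Line = Tuple[float, float]  # (slope, intercept) of a toy 1-D linear member


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Line:
    """Ordinary least-squares fit of y ~ w * x + b."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    var = sum((x - mx) ** 2 for x in xs)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    w = cov / var
    return w, my - w * mx


def boosted_predict(members: Sequence[Line], x: float, shrinkage: float = 1.0) -> float:
    # F(x) = sum over members of shrinkage * h_m(x)
    return sum(shrinkage * (w * x + b) for w, b in members)


def boost(xs: Sequence[float], ys: Sequence[float],
          num_members: int, shrinkage: float = 1.0) -> List[Line]:
    members: List[Line] = []
    for _ in range(num_members):
        preds = [boosted_predict(members, x, shrinkage) for x in xs]
        # For squared error, the negative gradient w.r.t. the prediction is the residual.
        residuals = [y - p for y, p in zip(ys, preds)]
        members.append(fit_line(xs, residuals))
    return members
```

On a linear target, shrinkage=1.0 fits everything with the first member, while shrinkage=0.5 leaves half of the remaining residual after each member, so three members recover 1 - 0.5**3 = 87.5% of the target. Smaller shrinkage trades more members for more gradual, typically better-regularised fits.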

static add_specific_args(parent_parser)

This method is used to add the specific arguments for the class.

Return type:

ArgumentParser
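A sketch of the common add_specific_args pattern: build a child argparse.ArgumentParser that inherits the parent's arguments and adds the class-specific ones. The class ToyGradientBoostingMethod and the flag names --method_num_members and --method_shrinkage are hypothetical illustrations, not yamle's actual command-line flags.

```python
import argparse


class ToyGradientBoostingMethod:
    """Hypothetical class illustrating the add_specific_args pattern."""

    @staticmethod
    def add_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Child parser inherits every argument already registered on parent_parser.
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        # Flag names and defaults below are illustrative assumptions.
        parser.add_argument("--method_num_members", type=int, default=5,
                            help="Number of ensemble members.")
        parser.add_argument("--method_shrinkage", type=float, default=1.0,
                            help="Shrinkage applied to each boosted member.")
        return parser


base = argparse.ArgumentParser(add_help=False)
base.add_argument("--seed", type=int, default=0)
parser = ToyGradientBoostingMethod.add_specific_args(base)
args = parser.parse_args(["--method_shrinkage", "0.1"])
```

Passing add_help=False on the child avoids a duplicate -h/--help conflict when the parent parser already defines one.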