yamle.methods.be module

yamle.methods.be.replace_with_be(model, num_members)
This function replaces every nn.Linear and nn.Conv2d layer in the model with a LinearBE or Conv2dBE layer, respectively. The model is modified in place, so nothing is returned.

Parameters:
  • model (nn.Module) – The model to replace the layers in.

  • num_members (int) – The number of members in the ensemble.

Return type:

None
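A minimal PyTorch sketch of this kind of in-place, recursive layer replacement is shown below. It handles only nn.Linear (Conv2d would be treated analogously), and `LinearBESketch` and `replace_with_be_sketch` are hypothetical names, not yamle's LinearBE or replace_with_be. The layer follows the BatchEnsemble rank-1 formulation (a shared weight W modulated per member by fast weights r_i and s_i) and assumes a member-major batch layout, in which the batch splits into num_members equal chunks and chunk i belongs to member i.

```python
import torch
import torch.nn as nn


class LinearBESketch(nn.Module):
    """Hypothetical BatchEnsemble linear layer (illustrative, not yamle's LinearBE).

    Member i computes y = ((x * r_i) @ W.T) * s_i + b_i with a shared W.
    """

    def __init__(self, linear: nn.Linear, num_members: int) -> None:
        super().__init__()
        self.num_members = num_members
        # Reuse the original layer's weight as the shared "slow" weight W.
        self.weight = linear.weight
        # Per-member rank-1 "fast" weights (input side r, output side s).
        self.r = nn.Parameter(torch.ones(num_members, linear.in_features))
        self.s = nn.Parameter(torch.ones(num_members, linear.out_features))
        # Per-member bias, initialised from the original bias when present.
        if linear.bias is not None:
            bias = linear.bias.detach()
        else:
            bias = torch.zeros(linear.out_features)
        self.bias = nn.Parameter(bias.unsqueeze(0).repeat(num_members, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed member-major layout: chunk i of the batch goes to member i.
        assert x.shape[0] % self.num_members == 0
        per_member = x.shape[0] // self.num_members
        r = self.r.repeat_interleave(per_member, dim=0)
        s = self.s.repeat_interleave(per_member, dim=0)
        b = self.bias.repeat_interleave(per_member, dim=0)
        return ((x * r) @ self.weight.t()) * s + b


def replace_with_be_sketch(model: nn.Module, num_members: int) -> None:
    # Snapshot the children first, then swap nn.Linear layers in place and
    # recurse into every other submodule (e.g. nested nn.Sequential blocks).
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            setattr(model, name, LinearBESketch(child, num_members))
        else:
            replace_with_be_sketch(child, num_members)
```

Because r and s start at one and the bias is copied, the replaced model initially reproduces the original outputs and all members are identical; in practice the fast weights are often initialised randomly (for example with random signs) so that the members differ from the start.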

class yamle.methods.be.BEMethod(*args, **kwargs)

Bases: MemberMethod

This class extends the base method for BatchEnsemble models.

The difference is that, during validation and testing, the prediction step must concatenate the num_members dimension into the batch dimension.
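The evaluation-time reshaping can be sketched as follows, assuming a member-major layout in which the whole input batch is repeated once per member along dim 0, so that one forward pass evaluates every member. The helper names `tile_for_members` and `split_members`, and the (batch, num_members, ...) output layout, are hypothetical choices for illustration, not yamle's documented convention.

```python
import torch


def tile_for_members(x: torch.Tensor, num_members: int) -> torch.Tensor:
    # (B, ...) -> (num_members * B, ...): the batch is repeated once per member
    # along dim 0, so chunk i of the result is processed by member i.
    return torch.cat([x] * num_members, dim=0)


def split_members(y: torch.Tensor, num_members: int) -> torch.Tensor:
    # (num_members * B, ...) -> (B, num_members, ...): undo the tiling so each
    # input row carries one prediction per member.
    batch = y.shape[0] // num_members
    y = y.reshape(num_members, batch, *y.shape[1:])
    return y.transpose(0, 1)
```

Averaging the result of `split_members` over dim 1 then gives the ensemble's mean prediction for each input.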

Note that only Linear and Conv2d layers are supported; batch normalization layers are not. In practice this is not a problem (see google/edward2).