torch_concepts.nn.functional.prune_linear_layer

prune_linear_layer(linear: Linear, mask: Tensor, dim: int = 0) → Linear

Return a new nn.Linear in which the input features (dim=0) or output units (dim=1) of linear have been pruned according to mask.

Parameters:
  • linear (nn.Linear) – Layer to prune.

  • mask (1D Tensor[bool] or 0/1) – Mask over features: True/1 = keep, False/0 = drop. Its length must match the pruned dimension:
      - if dim=0: length == in_features
      - if dim=1: length == out_features

  • dim (int) – Dimension to prune:
      - 0: prune input features (columns of weight)
      - 1: prune output units (rows of weight)
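The column/row convention follows from how nn.Linear stores its parameters: weight has shape (out_features, in_features) and bias has shape (out_features,). Dropping inputs therefore removes weight columns and leaves bias unchanged, while dropping outputs removes weight rows and the matching bias entries. The following is a minimal sketch of that behavior in plain PyTorch. prune_linear_sketch is a hypothetical name, and this is not the torch_concepts implementation, which may differ in validation or in how parameters are copied.

```python
import torch
from torch import nn


def prune_linear_sketch(linear: nn.Linear, mask: torch.Tensor, dim: int = 0) -> nn.Linear:
    """Hypothetical sketch: return a new nn.Linear keeping only masked-in features.

    dim=0 keeps a subset of input features (columns of weight);
    dim=1 keeps a subset of output units (rows of weight and entries of bias).
    """
    # Accept bool or 0/1 masks; any nonzero value means "keep".
    keep = mask.to(device=linear.weight.device, dtype=torch.bool)

    if dim == 0:
        if keep.shape != (linear.in_features,):
            raise ValueError("dim=0 requires a 1D mask of length in_features")
        weight = linear.weight[:, keep]   # drop input columns
        bias = linear.bias                # bias is per output, so unchanged
        in_f, out_f = int(keep.sum()), linear.out_features
    elif dim == 1:
        if keep.shape != (linear.out_features,):
            raise ValueError("dim=1 requires a 1D mask of length out_features")
        weight = linear.weight[keep, :]   # drop output rows
        bias = None if linear.bias is None else linear.bias[keep]
        in_f, out_f = linear.in_features, int(keep.sum())
    else:
        raise ValueError("dim must be 0 (inputs) or 1 (outputs)")

    # Build a fresh layer on the same device/dtype and copy the kept parameters.
    pruned = nn.Linear(in_f, out_f, bias=linear.bias is not None,
                       device=linear.weight.device, dtype=linear.weight.dtype)
    with torch.no_grad():
        pruned.weight.copy_(weight)
        if bias is not None:
            pruned.bias.copy_(bias)
    return pruned


# Usage: drop the second input feature of a 4 -> 3 layer.
layer = nn.Linear(4, 3)
pruned = prune_linear_sketch(layer, torch.tensor([1, 0, 1, 1]), dim=0)
print(pruned.in_features, pruned.out_features)  # 3 3
```

Under this sketch, pruning inputs is equivalent to zeroing the dropped features before calling the original layer, and pruning outputs is equivalent to selecting the kept entries of the original layer's output.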