Swapping Out Transformer Subcomponents
Author: Spencer Poff
Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, this is not easily accomplished with subclassing without copying and pasting much of the original implementation.
To make this easier and avoid copypasta, we provide the @swappable decorator.
Making a Module Swappable
Let’s say you have an existing class, TransformerLayer, that uses a module that you’d like to modify, TransformerFFN. You can make that FFN swappable in two steps:
Decorate TransformerLayer with @swappable, passing in a name for the component you’d like to swap and its default class/constructor:
@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
    ...
At runtime, the class for ffn will be added to a property, swappables, of TransformerLayer.
Replace your instantiation of TransformerFFN with a call to that constructor:
self.feedforward = self.swappables.ffn(opt, ...)
That’s it!
Making the Swap
You can now replace TransformerFFN with whatever class or constructor you want before instantiating TransformerLayer:
layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...)
As long as NewCustomFFN has the same __init__ and forward method signatures as TransformerFFN, everything should just work.
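One way to catch a mismatch early is to compare signatures before swapping. The following is a minimal sketch using the standard library's inspect module; same_signature and the Toy/Compatible/Incompatible classes are hypothetical helpers invented for illustration and are not part of ParlAI.

```python
import inspect


def same_signature(default_cls, replacement_cls, method_names=("__init__", "forward")):
    """Return True if each named method has an identical signature on both classes."""
    for name in method_names:
        default_sig = inspect.signature(getattr(default_cls, name))
        replacement_sig = inspect.signature(getattr(replacement_cls, name))
        if default_sig != replacement_sig:
            return False
    return True


# Hypothetical stand-ins; the real TransformerFFN takes different arguments.
class ToyFFN:
    def __init__(self, opt, dim):
        self.dim = dim

    def forward(self, x):
        return x


class CompatibleFFN:
    def __init__(self, opt, dim):
        self.dim = dim * 2

    def forward(self, x):
        return x


class IncompatibleFFN:
    # The extra required argument would break existing call sites.
    def __init__(self, opt, dim, extra_arg):
        self.dim = dim

    def forward(self, x):
        return x


compatible = same_signature(ToyFFN, CompatibleFFN)
incompatible = same_signature(ToyFFN, IncompatibleFFN)
```

Note that this comparison is stricter than necessary: it also treats differing defaults or annotations as a mismatch, so a looser check may suit some replacements better.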
For examples, see:
parlai/agents/examples/transformer_variant.py
projects/params_vs_compute/hash_ladder/hash_ladder.py
Composability
Since the swapping happens before instantiation, decorated components can be transparently composed. For example:
model = TransformerGeneratorModel.with_components(
    encoder=TransformerEncoder.with_components(
        layer=TransformerEncoderLayer.with_components(
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
    decoder=TransformerDecoder.with_components(
        layer=TransformerDecoderLayer.with_components(
            encoder_attention=MultiHeadAttention,
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
)(opt=self.opt, dictionary=self.dict)
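As an analogy for why this composes, consider the same idea expressed with functools.partial: a constructor that has been configured but not yet called can be passed anywhere a class is expected. This is only an illustration of the principle, not how with_components is implemented, and every Toy* name below is hypothetical.

```python
from functools import partial


class ToyFFN:
    def __init__(self, dim):
        self.dim = dim


class ToyCustomFFN(ToyFFN):
    pass


class ToyLayer:
    def __init__(self, dim, ffn=ToyFFN):
        self.feedforward = ffn(dim)


class ToyEncoder:
    def __init__(self, dim, layer=ToyLayer):
        self.layer = layer(dim)


# Configure constructors from the inside out; nothing is instantiated yet.
custom_layer = partial(ToyLayer, ffn=ToyCustomFFN)
custom_encoder = partial(ToyEncoder, layer=custom_layer)

# Only this call builds objects, and the innermost choice is honoured.
encoder = custom_encoder(8)
```

Because each level only receives something it can call like a constructor, it does not need to know whether the levels beneath it were customised.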
Implementation
See parlai/agents/transformer/modules/modular.py
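For intuition only, here is a toy, torch-free sketch of how a decorator with this behaviour could be written. It is not the code in modular.py: toy_swappable, the subclass-based with_components, and all Toy* classes are assumptions made up for this example.

```python
from types import SimpleNamespace


def toy_swappable(**defaults):
    """Simplified stand-in for @swappable; not ParlAI's implementation."""

    def wrap(cls):
        # Expose the default component classes as `cls.swappables.<name>`.
        cls.swappables = SimpleNamespace(**defaults)

        def with_components(klass, **overrides):
            # Return a subclass carrying the overridden components, so the
            # decorated class itself keeps its defaults.
            merged = {**vars(klass.swappables), **overrides}
            return type(klass.__name__, (klass,), {"swappables": SimpleNamespace(**merged)})

        cls.with_components = classmethod(with_components)
        return cls

    return wrap


class ToyFFN:
    def __init__(self, dim):
        self.dim = dim

    def forward(self, x):
        return x * 2


class ToyCustomFFN(ToyFFN):
    def forward(self, x):
        return x * 3


@toy_swappable(ffn=ToyFFN)
class ToyLayer:
    def __init__(self, dim):
        # Look the constructor up at instantiation time instead of hard-coding it.
        self.feedforward = self.swappables.ffn(dim)

    def forward(self, x):
        return self.feedforward.forward(x)


default_layer = ToyLayer(4)
custom_layer = ToyLayer.with_components(ffn=ToyCustomFFN)(4)
```

In this sketch, with_components returns a new class rather than mutating ToyLayer, which is what lets the default and the customised variant coexist in the same process.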