# Fine-tuning a BERT model for information retrieval

name: &name BertIR
trainer:
  devices: 1 # number of GPUs; 0 for CPU, or a list of GPU indices
  num_nodes: 1
  max_epochs: 2 # number of training epochs
  max_steps: -1 # takes precedence over max_epochs
  accumulate_grad_batches: 1 # accumulate gradients every k batches
  precision: 16 # set to 16 to use AMP
  accelerator: gpu
  strategy: ddp
  log_every_n_steps: 1 # logging interval
  val_check_interval: 0.05 # set to 0.25 to validate 4 times per epoch, or an int for a number of iterations
  enable_checkpointing: False # provided by exp_manager
  logger: false # provided by exp_manager
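
# A quick usage sketch (assumptions: this file is saved as bert_ir.yaml and is
# consumed by a Hydra-decorated NeMo training script, hypothetically named
# bert_ir.py); any key in this config can then be overridden from the command
# line with standard Hydra dot-path syntax:
#
#   python bert_ir.py trainer.devices=2 trainer.max_epochs=3 trainer.precision=32
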
model:
  nemo_path: null # exported .nemo path

  language_model:
    pretrained_model_name: bert-base-uncased
    sim_score_dropout: 0.1
    lm_checkpoint: null
    config:
      attention_probs_dropout_prob: 0.1
      hidden_act: gelu
      hidden_dropout_prob: 0.1
      hidden_size: 768
      initializer_range: 0.02
      intermediate_size: 3072
      max_position_embeddings: 512
      num_attention_heads: 12
      num_hidden_layers: 12
      type_vocab_size: 2
      vocab_size: 30522
    config_file: null # json file; takes precedence over config
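
    # Note: the config block above spells out the standard bert-base-uncased
    # architecture (12 layers, 12 heads, hidden size 768, vocab of 30522);
    # it is ignored when config_file is set, since that takes precedence.
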
  tokenizer:
    tokenizer_name: ${model.language_model.pretrained_model_name} # tokenizer that inherits from TokenizerSpec
    vocab_file: null # path to vocab file
    tokenizer_model: null # tokenizer model for sentencepiece
    special_tokens: null
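
  # For reference, a rough (assumed) equivalent of what the tokenizer section
  # resolves to when tokenizer_name is a Hugging Face model name:
  #
  #   from transformers import AutoTokenizer
  #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
  #
  # vocab_file and tokenizer_model only matter for custom vocabularies and
  # sentencepiece tokenizers, respectively.
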
  train_ds:
    passages: null # path to file with passages and their indices
    queries: null # path to file with training queries and their indices
    query_to_passages: null
    # path to file with training examples, which have the form
    # (query_id, relevant_passage_id, irrelevant_passage_1_id, ..., irrelevant_passage_n_id)
    num_negatives: 10
    batch_size: 6
    psg_cache_format: npz
    shuffle: true
    num_samples: -1 # number of samples to use; -1 means the whole dataset
    num_workers: 1
    drop_last: false
    pin_memory: false
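
  # An illustrative (assumed, ids made up) query_to_passages line for
  # num_negatives: 10, i.e. a query id, its relevant passage id, and ten
  # irrelevant passage ids, whitespace-separated:
  #
  #   42 1337 17 96 205 311 478 512 640 777 803 918
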
  validation_ds:
    passages: null # path to file with passages and their indices
    queries: null # path to file with validation queries and their indices
    query_to_passages: null # path to file with passages to re-rank for each validation query
    num_negatives: 10
    batch_size: 6
    psg_cache_format: pkl
    shuffle: false
    num_samples: -1 # number of samples to use; -1 means the whole dataset
    num_workers: 1
    drop_last: false
    pin_memory: false

  optim:
    name: adam
    lr: 1e-5
    betas: [0.9, 0.999]
    weight_decay: 0

    sched:
      name: WarmupAnnealing
      warmup_steps: null
      warmup_ratio: 0.05
      last_epoch: -1

      # pytorch lightning args
      monitor: val_loss
      reduce_on_plateau: false
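
    # With warmup_ratio: 0.05 the learning rate is warmed up over the first 5%
    # of total training steps and annealed for the rest of the run; e.g. in a
    # 10,000-step run, steps 0-500 are warmup. Set warmup_steps instead to give
    # the warmup length as an absolute step count.
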
exp_manager:
  exp_dir: null # where to store logs and checkpoints
  name: *name # name of the experiment
  create_tensorboard_logger: True
  create_checkpoint_callback: True
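
# exp_manager supplies the logger and checkpointing that were disabled in the
# trainer block above (logger: false, enable_checkpointing: False), writing
# TensorBoard logs and checkpoints under exp_dir/name.
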
hydra:
  run:
    dir: .
  job_logging:
    root:
      handlers: null
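
# A minimal sketch of inspecting this config from Python (assumption: the file
# is saved as bert_ir.yaml; OmegaConf resolves the ${...} interpolation on
# access):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("bert_ir.yaml")
#   cfg.model.train_ds.queries = "train_queries.tsv"  # hypothetical path
#   print(cfg.model.tokenizer.tokenizer_name)         # -> bert-base-uncased
#   print(OmegaConf.to_yaml(cfg.trainer))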