""" |
|
Sample command to run the script: |
|
|
|
python multi_label_intent_slot_classification.py \ |
|
model.data_dir=/home/user/multiatis \ |
|
model.validation_ds.prefix=dev \ |
|
model.test_ds.prefix=dev \ |
|
trainer.gpus=[0] \ |
|
+trainer.fast_dev_run=true \ |
|
exp_manager.exp_dir=checkpoints |
|
|
|
fast_dev_run=false will save checkpoints for the model |
|
""" |

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models import MultiLabelIntentSlotClassificationModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="multi_label_intent_slot_classification_config")
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params:\n {OmegaConf.to_yaml(cfg)}')
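
    # Build the PyTorch Lightning trainer from the trainer section of the config and register
    # the experiment manager, which handles logging and checkpoint directories for the run.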
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
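
    # Instantiate the multi-label intent and slot classification model from the model config.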
    model = MultiLabelIntentSlotClassificationModel(cfg.model, trainer=trainer)

    logging.info("================================================================================================")
    logging.info('Starting training...')
    trainer.fit(model)
    logging.info('Training finished!')
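
    # fast_dev_run does not save checkpoints, so skip testing and inference in that case.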
    if trainer.fast_dev_run:
        return

    logging.info("================================================================================================")
    logging.info("Starting the testing of the trained model on the test set...")
    logging.info("We will load the latest checkpoint saved during training...")

    eval_model = model
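
    # Point the evaluation model at the data directory and build the test dataloader from the config.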
    eval_model.update_data_dir_for_testing(data_dir=cfg.model.data_dir)
    eval_model.setup_test_data(test_data_config=cfg.model.test_ds)
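
    # ckpt_path=None evaluates the current in-memory model weights rather than reloading a checkpoint file.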
    trainer.test(model=eval_model, ckpt_path=None, verbose=False)
    logging.info("Testing finished!")
    eval_model.optimize_threshold(cfg.model.test_ds, 'dev')

    logging.info("======================================================================================")
    logging.info("Evaluate the model on the given queries...")
    queries = [
        'i would like to find a flight from charlotte to las vegas that makes a stop in st. louis',
        'on april first i need a ticket from tacoma to san jose departing before 7 am',
        'how much is the limousine service in boston',
    ]
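
    # predict_from_examples returns the predicted intents and slots for each query; the third
    # return value (the raw prediction list) is not used here.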
    pred_intents, pred_slots, pred_list = eval_model.predict_from_examples(queries, cfg.model.test_ds)
    logging.info('The prediction results of some sample queries with the trained model:')

    for query, intent, slots in zip(queries, pred_intents, pred_slots):
        logging.info(f'Query : {query}')
        logging.info(f'Predicted Intents: {intent}')
        logging.info(f'Predicted Slots: {slots}')

    logging.info("Inference finished!")


if __name__ == '__main__':
    main()