Training
Training Command
Quick Start
This example demonstrates how to train a model on the LibriSpeech dataset using the base
model configuration.
This guide assumes that the user has followed the installation guide
and has prepared LibriSpeech according to the data preparation guide.
The batch size arguments should be chosen to suit the specifications of your machine. More information on choosing them can be found here.
Recommendations for LibriSpeech training are:
- a global batch size of 1024 for a 24GB GPU
- use all train-* subsets and validate on dev-clean
- 42000 steps is sufficient for 960hrs of train data
- adjust number of GPUs using the --num_gpus=<NUM_GPU> argument
To launch training inside the container, using a single GPU, run the following command:
./scripts/train.sh \
--data_dir=/datasets/LibriSpeech \
--train_manifests librispeech-train-clean-100-flac.json librispeech-train-clean-360-flac.json librispeech-train-other-500-flac.json \
--val_manifests librispeech-dev-clean-flac.json \
--model_config configs/base-8703sp_run.yaml \
--num_gpus 1 \
--global_batch_size 1024 \
--grad_accumulation_batches 16 \
--batch_split_factor 8 \
--val_batch_size 1 \
--training_steps 42000
The output of the training command is logged to /results/training_log_[timestamp].txt.
The arguments are logged to /results/training_args_[timestamp].json,
and the config file is saved to /results/[config file name]_[timestamp].yaml.
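The batch-size arguments in the command above are linked. As a rough sketch, assuming the global batch is divided evenly across GPUs and gradient-accumulation steps (an assumption about how train.sh interprets these flags, not something stated in this guide; the helper name per_gpu_batch_size is hypothetical, and --batch_split_factor is not modelled), the per-GPU batch for each forward/backward pass works out as follows:

```python
def per_gpu_batch_size(global_batch_size: int, num_gpus: int,
                       grad_accumulation_batches: int) -> int:
    """Utterances each GPU processes per forward/backward pass.

    Assumes: global_batch_size =
        per_gpu_batch * grad_accumulation_batches * num_gpus
    """
    denominator = num_gpus * grad_accumulation_batches
    if global_batch_size % denominator != 0:
        raise ValueError(
            "global_batch_size must be divisible by "
            "num_gpus * grad_accumulation_batches"
        )
    return global_batch_size // denominator


# Values from the single-GPU command above.
print(per_gpu_batch_size(1024, num_gpus=1, grad_accumulation_batches=16))  # 64
# With two GPUs, halving the accumulation steps keeps the per-GPU batch at 64.
print(per_gpu_batch_size(1024, num_gpus=2, grad_accumulation_batches=8))   # 64
```

Under this assumption, changing --num_gpus without adjusting --grad_accumulation_batches changes the per-GPU batch, and therefore the memory needed on each GPU.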
Defaults to update for your own data
When training on your own data you will need to change the following args from their defaults to reflect your setup:
- --data_dir
- --train_manifests/--train_tar_files
  - To specify multiple training manifests, use --train_manifests followed by space-delimited file names, like this: --train_manifests first.json second.json third.json.
- --val_manifests/--val_tar_files/(--val_audio_dir + --val_txt_dir)
- --model_config=configs/base-8703sp_run.yaml (or the _run.yaml config file created by your scripts/preprocess_<your dataset>.sh script)
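The space-delimited form used for --train_manifests behaves like a multi-value command-line option. The sketch below illustrates that behaviour with a stand-in argparse parser using nargs="+". This is an assumption made for illustration only: the real argument definitions are in args/train.py and args/shared.py and may differ, and dev.json is a hypothetical file name.

```python
import argparse

# Hypothetical stand-in parser; not the project's actual argument definitions.
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str)
# nargs="+" collects one or more space-delimited values into a list.
parser.add_argument("--train_manifests", type=str, nargs="+")
parser.add_argument("--val_manifests", type=str, nargs="+")

args = parser.parse_args(
    "--data_dir /datasets/LibriSpeech "
    "--train_manifests first.json second.json third.json "
    "--val_manifests dev.json".split()
)
print(args.train_manifests)  # ['first.json', 'second.json', 'third.json']
print(args.val_manifests)    # ['dev.json']
```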
The audio paths stored in manifests are relative to --data_dir. For example,
if your audio file path is train/1.flac and the data_dir is /datasets/LibriSpeech, then the dataloader
will try to load audio from /datasets/LibriSpeech/train/1.flac.
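The path resolution described above can be sketched in a few lines. This is an illustration only, not the dataloader's actual code; the function name resolve_audio_path is hypothetical.

```python
import posixpath


def resolve_audio_path(data_dir: str, manifest_path: str) -> str:
    """Join a manifest-relative audio path onto data_dir."""
    return posixpath.join(data_dir, manifest_path)


print(resolve_audio_path("/datasets/LibriSpeech", "train/1.flac"))
# /datasets/LibriSpeech/train/1.flac
```

If a manifest was written with paths relative to a different root, either regenerate it or point --data_dir at that root so the joined paths exist.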
The learning-rate scheduler argument defaults have been tested on 1k-50k hrs of data. When training on larger datasets you may need to tune them. These arguments are:
- --warmup_steps: number of steps over which the learning rate is linearly increased from --min_learning_rate
- --hold_steps: number of steps over which the learning rate is kept constant after warmup
- --half_life_steps: the half life (in steps) for exponential learning rate decay
If you are using more than 50k hrs, it is recommended to start with --half_life_steps=10880 and increase if necessary. Note that increasing --half_life_steps increases the probability of diverging later in training.
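To make the three scheduler arguments concrete, here is a minimal sketch of a warmup, hold, then exponential-decay schedule. It is not the project's scheduler: the parameter name peak_lr, the example values, and the choice to floor the decayed rate at min_learning_rate are all assumptions made for illustration.

```python
def learning_rate(step, peak_lr, min_learning_rate,
                  warmup_steps, hold_steps, half_life_steps):
    """Learning rate at `step` under warmup -> hold -> exponential decay."""
    if step < warmup_steps:
        # Linear ramp from min_learning_rate up to peak_lr.
        frac = step / warmup_steps
        return min_learning_rate + frac * (peak_lr - min_learning_rate)
    if step < warmup_steps + hold_steps:
        # Constant at the peak after warmup.
        return peak_lr
    # Halve the rate every half_life_steps.
    # Flooring at min_learning_rate is an assumption of this sketch.
    decay_steps = step - warmup_steps - hold_steps
    decayed = peak_lr * 0.5 ** (decay_steps / half_life_steps)
    return max(decayed, min_learning_rate)


# Hypothetical values, chosen only to show the shape of the schedule.
for step in (0, 500, 1000, 3000, 13880, 24760):
    lr = learning_rate(step, peak_lr=4e-3, min_learning_rate=4e-4,
                       warmup_steps=1000, hold_steps=2000,
                       half_life_steps=10880)
    print(step, lr)
```

In this sketch a larger half_life_steps keeps the learning rate higher for longer, which is consistent with the warning above about divergence later in training.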
Arguments
To resume training or fine-tune a checkpoint, see the documentation here.
The default setup saves an overwriting checkpoint every time the Word Error Rate (WER) improves on the dev set.
Also, a non-overwriting checkpoint is saved at the end of training.
By default, checkpoints are saved every 5000 steps, and the frequency can be changed by setting --save_frequency=N.
For a complete set of arguments and their respective docstrings see args/train.py and args/shared.py.
Data Augmentation for Difficult Target Data
If you are targeting a production setting where background noise is common or audio arrives at 8 kHz, see here for guidelines.
Monitor training
To view the progress of your training you can use TensorBoard. See the TensorBoard documentation for more information on how to set up and use TensorBoard.
Profiling
To profile training, see these instructions.
Controlling emission latency
See these instructions on how to control emission latency of a model.
Next Steps
Having trained a model:
- If you'd like to evaluate it on more test/validation data, go to the validation docs.
- If you'd like to export a model checkpoint for inference, go to the hardware export docs.