Multi-Task Learning: A Simple Explanation

The overall premise of machine learning is to learn from historical data in order to infer, estimate, or predict outcomes. Let us assume that a machine learning model is fed a set of parameters (collectively referred to as features) that can influence the outcome produced by the model (referred to as the label). In most cases the goal is to focus on a single well-defined task. During training, the model learns to predict or estimate the label by minimizing a cost function that measures how far its predictions are from the known labels.
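To make the single-task setup concrete, here is a minimal sketch that is not taken from the original post. It assumes a linear model, a mean-squared-error cost, and plain full-batch gradient descent in NumPy. The synthetic data, the learning rate, and the function names (`mse_cost`, `train_single_task`) are hypothetical choices made for illustration.

```python
import numpy as np

def mse_cost(w, b, X, y):
    """Mean squared error between the predictions X @ w + b and the labels y."""
    residual = X @ w + b - y
    return float(np.mean(residual ** 2))

def train_single_task(X, y, lr=0.1, epochs=500):
    """Fit w and b by minimizing the MSE cost with full-batch gradient descent."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        residual = X @ w + b - y            # prediction error for each sample
        grad_w = 2.0 * X.T @ residual / n   # d(cost)/dw
        grad_b = 2.0 * residual.mean()      # d(cost)/db
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

# Hypothetical synthetic data: 3 features, labels from a known linear rule plus noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w + 0.3 + rng.normal(scale=0.1, size=200)

w, b = train_single_task(X, y)
print("learned weights:", np.round(w, 2), "bias:", round(float(b), 2))
print("final cost:", round(mse_cost(w, b, X, y), 4))
```

Real object recognition models are far larger, but the training loop has the same shape: compute predictions, measure the cost against the known labels, and move the parameters in the direction that lowers it.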

As an example, object recognition models learn to identify objects by being trained on large labeled datasets, and can then identify objects in unlabeled pictures presented to them during inference. Despite the complexities involved in building and training effective object recognition models, they are essentially trained to do a single task: object recognition.

Well-designed and well-trained models often yield good results, but it is not unusual to run into models that perform poorly and stubbornly resist even the most aggressive refinements. This can happen for a variety of reasons; one in particular (relevant to multi-task learning) is that the model becomes fixated on a narrow set of features (overfitting) and loses sight of the bigger picture.

Multi-task learning can often help improve the performance of such models. It entails expanding the model by artificially adding auxiliary tasks that the original model does not target but that are relevant to the main task. The expanded model is then trained to optimize all tasks (primary and auxiliary) at the same time, usually by minimizing a combined cost function such as a weighted sum of the individual task costs.

As an example, imagine struggling to get acceptable performance from a model built to predict the probability that a jet engine fails within a given period. Such a model might be fed a large set of environmental and operational parameters such as flight hours, frequency of service, previous failures, number of takeoffs and landings, and many more. With multi-task learning we can artificially introduce auxiliary tasks such as predicting the probability that engine temperature or vibration will exceed certain thresholds. Engine temperature and vibration are not our primary interest, but excess temperature and vibration certainly do nothing to extend the life of any type of machinery, let alone a jet engine, so learning to predict them is likely to help the model pick up signals that also bear on failure.

There are many types and implementations of multi-task models, but one popular method entails building a deep neural network that accepts features from all tasks (primary and auxiliary). Its first few hidden layers are shared by all tasks, while the upper hidden layers are task-specific (see Fig. 2 below). In other words, the shared layers learn with the “big picture” in mind, and the task-specific layers take a more refined path to produce each task's final outcome.
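The shared-bottom architecture described above (often called hard parameter sharing) can be sketched in NumPy as follows. This is an illustrative toy under stated assumptions, not the post's model: the data is synthetic, and the two binary tasks (`failure` as the primary task, `overheat` as a hypothetical auxiliary threshold task) are invented stand-ins for the jet engine example. The layer sizes, tanh activations, cross-entropy cost, and auxiliary weight of 0.5 are arbitrary choices. Gradients are written out by hand to keep the example dependency-free; in practice a framework with automatic differentiation would normally be used.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(p, y, eps=1e-12):
    """Binary cross-entropy averaged over samples."""
    return float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))

def init_params(n_features, n_shared, n_task_hidden, tasks, rng):
    """One shared hidden layer, plus a task-specific hidden layer and output per task."""
    params = {"W_shared": rng.normal(scale=0.5, size=(n_features, n_shared)),
              "b_shared": np.zeros(n_shared)}
    for t in tasks:
        params[f"W_{t}"] = rng.normal(scale=0.5, size=(n_shared, n_task_hidden))
        params[f"b_{t}"] = np.zeros(n_task_hidden)
        params[f"v_{t}"] = rng.normal(scale=0.5, size=n_task_hidden)
        params[f"c_{t}"] = 0.0
    return params

def forward(params, X, tasks):
    """Shared activations, per-task hidden activations, and per-task probabilities."""
    h = np.tanh(X @ params["W_shared"] + params["b_shared"])   # used by every task
    towers, probs = {}, {}
    for t in tasks:
        g = np.tanh(h @ params[f"W_{t}"] + params[f"b_{t}"])    # task-specific layer
        towers[t] = g
        probs[t] = sigmoid(g @ params[f"v_{t}"] + params[f"c_{t}"])
    return h, towers, probs

def combined_loss(params, X, labels, weights):
    """Joint cost: sum over tasks of weight_t * BCE_t."""
    _, _, probs = forward(params, X, list(weights))
    return sum(w * bce(probs[t], labels[t]) for t, w in weights.items())

def train_step(params, X, labels, weights, lr):
    """One full-batch gradient-descent step on the joint cost (hand-written backprop)."""
    tasks, n = list(weights), X.shape[0]
    h, towers, probs = forward(params, X, tasks)
    grads, dh = {}, np.zeros_like(h)
    for t in tasks:
        dlogit = weights[t] * (probs[t] - labels[t]) / n          # gradient w.r.t. each logit
        g = towers[t]
        grads[f"v_{t}"] = g.T @ dlogit
        grads[f"c_{t}"] = dlogit.sum()
        dpre = np.outer(dlogit, params[f"v_{t}"]) * (1 - g ** 2)  # back through tanh
        grads[f"W_{t}"] = h.T @ dpre
        grads[f"b_{t}"] = dpre.sum(axis=0)
        dh += dpre @ params[f"W_{t}"].T   # every task sends gradient into the shared layer
    dpre_shared = dh * (1 - h ** 2)
    grads["W_shared"] = X.T @ dpre_shared
    grads["b_shared"] = dpre_shared.sum(axis=0)
    return {k: params[k] - lr * grads[k] for k in params}

# Hypothetical synthetic data standing in for the jet engine example.
rng = np.random.default_rng(0)
n, d = 500, 6
X = rng.normal(size=(n, d))        # e.g. flight hours, takeoffs, service frequency, ...
wear = X @ rng.normal(size=d)      # hidden "wear" signal that both tasks depend on
labels = {
    "failure": (wear + rng.normal(scale=1.0, size=n) > 1.0).astype(float),   # primary task
    "overheat": (wear + rng.normal(scale=0.5, size=n) > 0.0).astype(float),  # auxiliary task
}
weights = {"failure": 1.0, "overheat": 0.5}   # auxiliary weight is an arbitrary choice

params = init_params(d, n_shared=16, n_task_hidden=8, tasks=list(weights), rng=rng)
loss_before = combined_loss(params, X, labels, weights)
for _ in range(300):
    params = train_step(params, X, labels, weights, lr=0.5)
loss_after = combined_loss(params, X, labels, weights)
print(f"combined loss: {loss_before:.3f} -> {loss_after:.3f}")
```

The line that accumulates `dh` is where the auxiliary task influences the shared representation that the primary task also relies on; raising or lowering its weight controls how strongly it does so.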

So why does this work? A full treatment is beyond the scope of this post, but you can view it as not letting the model become overly biased by the noise of any single task. In other words, training is forced to take place in a regime with several tasks whose noise is uncorrelated, rather than letting a single runaway noise source pollute the outcome. Another way of thinking about this is to imagine a larger, more diverse model trained on a bigger dataset with a very diverse feature set.
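One toy way to see the uncorrelated-noise intuition: the shared layers are updated with the combined gradient of all tasks, and if each of k tasks contributes its own independent noise of variance σ², the variance of the averaged noise falls to σ²/k, whereas a single noise source shared by every task is not reduced at all. The sketch below simulates only that noise term, assuming equal task weights and equal-variance Gaussian noise; the function name `averaged_noise_variance` is hypothetical, and the simulation illustrates the intuition rather than proving that multi-task learning helps.

```python
import numpy as np

def averaged_noise_variance(n_tasks, correlated, n_trials=100_000, sigma=1.0, seed=0):
    """Toy model: variance of per-task gradient noise after averaging over tasks.

    Each of n_tasks tasks contributes Gaussian noise with standard deviation sigma
    to a shared parameter's update. If correlated is True, all tasks share one
    noise draw (a single dominant noise source); otherwise the draws are independent.
    """
    rng = np.random.default_rng(seed)
    if correlated:
        single = rng.normal(scale=sigma, size=(n_trials, 1))
        noise = np.repeat(single, n_tasks, axis=1)   # same noise copied to every task
    else:
        noise = rng.normal(scale=sigma, size=(n_trials, n_tasks))
    return float(noise.mean(axis=1).var())

# With sigma=1, the independent column should sit near 1/k; the shared column stays near 1.
for k in (1, 2, 4, 8):
    indep = averaged_noise_variance(k, correlated=False)
    shared = averaged_noise_variance(k, correlated=True)
    print(f"tasks={k}: independent noise {indep:.3f}, shared noise {shared:.3f}")
```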