Exploring the Effect of Data Heterogeneity in Federated Learning for Language Modeling


  • Marina Lin, Department of Computer Science, George Mason University, Fairfax, VA
  • Michael Crawshaw, Department of Computer Science, George Mason University, Fairfax, VA
  • Mingrui Liu, Department of Computer Science, George Mason University, Fairfax, VA




Federated Learning (FL) is a large-scale distributed machine learning paradigm characterized by decentralized data and limited communication. Data heterogeneity, that is, differences in the data distributions of different clients, is a central challenge in FL and can significantly slow down learning. Language modeling, which involves predicting the probability of a sequence of words in context, is the foundation of chatbots such as ChatGPT. Currently, the effect of data heterogeneity is not well understood in practical settings such as language modeling. For instance, (i) it is unclear whether the conventional assumption of bounded gradient dissimilarity holds during the training of language models. Furthermore, (ii) algorithms introduced to handle data heterogeneity rely on assumptions that may not hold in practice. For example, the SCAFFOLD algorithm relies on smoothness of the objective function, a condition that language models often violate. In this work, we investigate the role of heterogeneity in training Recurrent Neural Networks (RNNs) for language modeling by experimentally addressing (i) and (ii). To address (i), we measure the gradient dissimilarity and its correlation with performance degradation under data heterogeneity, in order to test whether the theoretical assumption holds. To address (ii), we compare SCAFFOLD against baselines to determine whether techniques designed to mitigate heterogeneity help in language modeling tasks. A better understanding of the impact of data heterogeneity on language modeling can help build more personalized and accurate chatbots that improve user experience.
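Point (i) concerns the bounded gradient dissimilarity assumption. One common form of it states that, at every model parameter x, the average over clients of ||∇f_i(x) − ∇f(x)||² is at most some constant ζ², where f_i is client i's loss and f is the global objective. The Python sketch below estimates the left-hand side from per-client gradients evaluated at the same global parameters. The function name `gradient_dissimilarity` and the optional client weighting are illustrative assumptions, not the measurement procedure used in this work.

```python
import numpy as np


def gradient_dissimilarity(client_grads, weights=None):
    """Weighted average of ||g_i - g_bar||^2 over clients.

    client_grads: shape (N, d), one flattened gradient per client, all
        evaluated at the same global model parameters.
    weights: optional per-client weights (e.g. local dataset sizes);
        uniform weights give (1/N) * sum_i ||g_i - g_bar||^2.
    """
    grads = np.asarray(client_grads, dtype=float)
    n = grads.shape[0]
    if weights is None:
        w = np.full(n, 1.0 / n)
    else:
        w = np.asarray(weights, dtype=float)
        w = w / w.sum()  # normalize so the weights sum to 1
    # Weighted mean of client gradients: an estimate of the global gradient.
    mean_grad = w @ grads
    # Squared Euclidean distance of each client gradient from the mean.
    sq_dev = np.sum((grads - mean_grad) ** 2, axis=1)
    return float(w @ sq_dev)
```

In a PyTorch training loop, each row could be obtained by flattening a client's gradient (for example with `torch.cat([p.grad.flatten() for p in model.parameters()])`, where `model` is a hypothetical client model). Tracking this value across rounds alongside validation loss is one way to examine its correlation with performance degradation. Note that stochastic minibatch gradients also add sampling noise to the estimate.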

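For point (ii), SCAFFOLD counters client drift with control variates. Each local step uses the corrected direction g_i(y) − c_i + c instead of the raw client gradient g_i(y), where c_i and c are the client's and the server's estimates of the client and global gradients. The NumPy sketch below shows one round with full client participation and the "option II" control-variate update from the original SCAFFOLD paper. It runs on hypothetical one-dimensional quadratic clients with different curvatures. The function name, hyperparameters, and toy objectives are assumptions for illustration only and do not reflect the RNN experiments of this work.

```python
import numpy as np


def scaffold_round(x, c, c_locals, grad_fns, lr_local=0.05, lr_global=1.0, local_steps=5):
    """One SCAFFOLD round with all clients participating (option II control variates).

    x: global parameters; c: server control variate;
    c_locals: list of client control variates;
    grad_fns: list of callables returning each client's gradient at given parameters.
    """
    delta_ys, delta_cs, new_c_locals = [], [], []
    for c_i, grad_i in zip(c_locals, grad_fns):
        y = x.copy()
        for _ in range(local_steps):
            # Drift-corrected local step: g_i(y) - c_i + c.
            y = y - lr_local * (grad_i(y) - c_i + c)
        # Option II: c_i+ = c_i - c + (x - y) / (K * lr_local).
        c_i_new = c_i - c + (x - y) / (local_steps * lr_local)
        delta_ys.append(y - x)
        delta_cs.append(c_i_new - c_i)
        new_c_locals.append(c_i_new)
    # Server aggregation; with full participation |S|/N = 1.
    x_new = x + lr_global * np.mean(delta_ys, axis=0)
    c_new = c + np.mean(delta_cs, axis=0)
    return x_new, c_new, new_c_locals


# Hypothetical toy clients: f_i(x) = 0.5 * h_i * (x - a_i)^2, with different curvatures.
h = [1.0, 3.0]
a = [np.array([0.0]), np.array([4.0])]
grad_fns = [lambda y, h_i=h_i, a_i=a_i: h_i * (y - a_i) for h_i, a_i in zip(h, a)]

x = np.zeros(1)
c = np.zeros(1)
c_locals = [np.zeros(1), np.zeros(1)]
for _ in range(200):
    x, c, c_locals = scaffold_round(x, c, c_locals, grad_fns)
# x is now close to the global minimizer sum(h_i * a_i) / sum(h_i) = 3.0
```

At SCAFFOLD's fixed point each c_i equals the client's true gradient at the global iterate and c equals zero, so the iterate settles at the true minimizer even though the clients disagree. The convergence guarantees behind this rely on smoothness. RNN losses are non-convex and need not be smooth, and that is the gap point (ii) examines empirically.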




College of Engineering and Computing: Department of Computer Science