Federated Learning for Medical AI
Updated: Jul 8, 2020
Conventional machine learning works on the paradigm of bringing the data to the code. The data used to train a model is centralized in a single store, either on a local disk or in a data center in the cloud, and samples are loaded from this central store during training.
Many valuable datasets, however, are highly decentralized in nature. Take, for instance, patient data held by hospitals. Each hospital's data forms a data island, and each island has characteristics it does not share with the islands of other hospitals. Hence, a model trained on data from one hospital may generalize poorly to data from another. The conventional approach of pooling data from all hospitals and training a single model on it presents several challenges. To begin with, hospitals are reluctant to share patient data, and regulations such as the GDPR, HIPAA, and the CCPA now make it virtually impossible for multiple organizations to pool their private patient data. Secondly, the data itself is in a state of constant change: for pooling to keep working, the exercise would have to be repeated frequently, which is even harder than doing it just once.
What is Federated Learning?
The term Federated Learning (FL) was coined by Google researchers in their paper titled ‘Communication-Efficient Learning of Deep Networks from Decentralized Data’. Initially proposed for mobile devices, FL advocates bringing code to the data instead of the other way round.
Federated Learning enables machine learning models to learn on private data without moving the data and compromising privacy.
In an FL scenario, there are multiple users — in our case hospitals — each owning a private data store. The data never leaves the hospital's computers. A central server maintains a global shared model which is sent to all hospitals (in practice, only a random subset of hospitals), and each client individually updates the model based on its private data. Each client then communicates only the model updates back to the server. The server aggregates these updates and applies one final update to the global model. This constitutes one round of federated learning. The process lets every participant benefit from the rich private data of all the others, without centralizing the data in one place or compromising privacy.
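The server-side aggregation step of the round described above can be sketched in a few lines. This is a minimal illustration of weighted Federated Averaging (the aggregation rule from the Google paper cited earlier), not any particular library's API; the function name and data layout are our own, and model weights are represented as plain NumPy arrays for simplicity.

```python
import numpy as np

def fedavg_round(global_weights, client_updates):
    """One server-side aggregation step of Federated Averaging.

    global_weights: list of numpy arrays (one per model layer).
    client_updates: list of (num_samples, weights) pairs, where each
    client's weights mirror the structure of global_weights.
    """
    total_samples = sum(n for n, _ in client_updates)
    new_weights = []
    for layer_idx in range(len(global_weights)):
        # Weight each client's contribution by its local dataset size,
        # so hospitals with more data pull the average harder.
        layer = sum(
            (n / total_samples) * w[layer_idx] for n, w in client_updates
        )
        new_weights.append(layer)
    return new_weights
```

For example, if one hospital with 1,000 samples and another with 3,000 samples report back, the second hospital's update receives three times the weight in the new global model.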
The single biggest obstacle to developing AI models for medical applications is the lack of sufficient training data. Sufficient data exists, but it is heavily fragmented across different institutions. Federated Learning allows multiple healthcare institutions to collaboratively build a global model without sharing their data, preserving patient privacy.
There are many challenges involved in practically implementing a federated learning system:
Unreliable connectivity: Some clients may not have reliable Internet connectivity and may fail to respond midway through a round.
Lockstep execution across devices with varying availability: Every round proceeds in a stepwise fashion, with the server waiting on clients before it can aggregate. As the number of participating clients grows, the implementation must tolerate stragglers and dropouts at every point of failure.
Limited storage and compute resources on the client: Clients are typically limited in storage and compute power. Especially when the clients are mobile devices, it is important that the FL implementation kicks in only when the device is available (i.e., it is plugged in and on an unmetered Internet connection).
Different data distributions across devices: Clients typically house highly private data which may vary widely from one client to another. In our example, the data from one hospital may have nothing in common with the data from another hospital.
Security: Inference attacks, where model weight updates are exploited (at a client or the server) to reveal properties of the original data, and model poisoning, where misbehaving clients intentionally contaminate the global model, need to be detected and thwarted.
Many open source platforms have sprung up to support developers building Federated Learning systems. The most popular among them are PySyft and TensorFlow Federated (TFF).
The rising popularity of edge computing will give a boost to research, development, and adoption of Federated Learning. In the years to come, Federated Learning will become an indispensable tool to carry out privacy-preserving, distributed learning on decentralized data.