AWS Machine Learning Blog

Reinventing a cloud-native federated learning architecture on AWS

Machine learning (ML), especially deep learning, requires large amounts of data to improve model performance. Customers often need to train a model with data from different regions, organizations, or AWS accounts. It is challenging to centralize such data for ML due to privacy requirements, the high cost of data transfer, or operational complexity.

Federated learning (FL) is a distributed ML approach that trains ML models on distributed datasets. The goal of FL is to improve the accuracy of ML models by using more data, while preserving the privacy and locality of distributed datasets. FL increases the amount of data available for training ML models, especially data associated with rare and new events, resulting in a more general ML model. Existing partner open-source FL solutions on AWS include FedML and NVIDIA FLARE. These open-source packages are deployed in the cloud by running them in virtual machines, without using the cloud-native services available on AWS.

In this post, you will learn how to build a cloud-native FL architecture on AWS. By using infrastructure as code (IaC) tools on AWS, you can deploy FL architectures with ease. A cloud-native architecture also takes full advantage of a variety of AWS services with proven security and operational excellence, thereby simplifying the development of FL.

We first discuss different approaches and challenges of FL. We then demonstrate how to build a cloud-native FL architecture on AWS. The sample code to build this architecture is available on GitHub. We use the AWS Cloud Development Kit (AWS CDK) to deploy the architecture with one-click deployment. The sample code demonstrates a scenario where the server and all clients belong to the same organization (the same AWS account), but their datasets cannot be centralized due to data localization requirements. The sample code supports horizontal and synchronous FL for training neural network models. The ML framework used at the FL clients is TensorFlow.

Overview of federated learning

FL typically involves a central FL server and a group of clients. Clients are compute nodes that perform local training. In an FL training round, the central server first sends a common global model to a group of clients. Clients train the global model with local data, then provide local models back to the server. The server aggregates the local models into a new global model, then starts a new training round. Training may continue for tens of rounds until the global model converges or the number of rounds reaches a threshold. In this way, FL exchanges ML models between the central FL server and clients, without moving training data to a central location.
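As a concrete illustration of this procedure, the following minimal sketch implements the training loop in Python, with models represented as single NumPy arrays and simple averaging as the aggregation step. The train_locally function is a hypothetical placeholder for whatever local training the clients perform; none of these names come from the sample code discussed later.

import numpy as np

def aggregate(local_models):
    # Average the clients' local models into a new global model
    # (models are plain NumPy weight arrays in this sketch).
    return np.mean(np.stack(local_models), axis=0)

def federated_training(global_model, train_locally, clients, max_rounds=10):
    # train_locally is any function mapping (global_model, client) to a
    # locally trained model; the loop stops after a fixed number of rounds.
    for _ in range(max_rounds):
        local_models = [train_locally(global_model, client) for client in clients]
        global_model = aggregate(local_models)
    return global_model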

There are two major categories of FL depending on the client type: cross-device and cross-silo. Cross-device FL trains a common global model by keeping all the training data locally on a large number of devices, such as mobile phones or IoT devices, with limited and unstable network connections. Therefore, the design of cross-device FL needs to account for the frequent joining and dropout of FL clients.

Cross-silo FL trains a global model on datasets distributed at different organizations and geo-distributed data centers. These datasets are prohibited from moving out of organizations and data center regions due to data protection regulations, operational challenges (such as data duplication and synchronization), or high costs. In contrast with cross-device FL, cross-silo FL assumes that organizations or data centers have reliable network connections, powerful computing resources, and addressable datasets.

FL has been applied to various industries, such as finance, healthcare, medicine, and telecommunications, where privacy preservation is critical or data localization is required. FL has been used to train a global model for financial crime detection among multiple financial institutions. The global model outperforms models trained with only local datasets by 20%. In healthcare, FL has been used to predict mortality of hospitalized patients based on electronic health records from multiple hospitals. The global model predicting mortality outperforms local models at all participating hospitals. FL has also been used for brain tumor segmentation. The global models for brain tumor segmentation perform similarly to the model trained by collecting distributed datasets at a central location. In telecommunications, FL can be applied to edge computing, wireless spectrum management, and 5G core networks.

There are many other ways to classify FL:

  • Horizontal or vertical – Depending on the partition of features in distributed datasets, FL can be classified as horizontal or vertical. In horizontal FL, all distributed datasets have the same set of features. In vertical FL, datasets have different groups of features, requiring additional communication patterns to align samples based on overlapping features.
  • Synchronous or asynchronous – Depending on the aggregation strategy at an FL server, FL can be classified as synchronous or asynchronous. A synchronous FL server aggregates local models from a selected set of clients into a global model. An asynchronous FL server immediately updates the global model after a local model is received from a client, thereby reducing the waiting time and improving training efficiency (see the sketch after this list).
  • Hub-and-spoke or peer-to-peer – The typical FL topology is hub-and-spoke, where a central FL server coordinates a set of clients. Another FL topology is peer-to-peer without any centralized FL server, where FL clients aggregate information from neighboring clients to learn a model.
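To make the synchronous/asynchronous distinction concrete, the following sketch shows an asynchronous update rule: instead of waiting to average a full set of local models (as in the synchronous loop sketched earlier), the server mixes in each local model as it arrives. The mixing rate alpha is a hypothetical parameter, not something defined in the sample code.

import numpy as np

def asynchronous_update(global_model, local_model, alpha=0.5):
    # Fold a single arriving local model into the global model immediately;
    # alpha controls how much the new local model contributes.
    return (1 - alpha) * np.asarray(global_model) + alpha * np.asarray(local_model)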

Challenges in FL

You can address the following challenges using algorithms running at FL servers and clients in a common FL architecture:

  • Data heterogeneity – FL clients’ local data can vary (i.e., data heterogeneity) due to particular geographic locations, organizations, or time windows. Data heterogeneity impacts the accuracy of global models, leading to more training iterations and longer training time. Many solutions have been proposed to mitigate the impact of data heterogeneity, such as optimization algorithms, partial data sharing among clients, and domain adaptation.
  • Privacy preservation – Local and global models may leak private information via an adversarial attack. Many privacy preservation approaches have been proposed for FL. A secure aggregation approach can be used to preserve the privacy of local models exchanged between FL servers and clients. Local and global differential privacy approaches bound the privacy loss by adding noise to local or global models, which provides a controlled trade-off between privacy and model accuracy (see the sketch after this list). Depending on the privacy requirements, combinations of different privacy preservation approaches can be used.
  • Federated analytics – Federated analytics provides statistical measurements of distributed datasets without violating privacy requirements. Federated analytics is important not only for data analysis across distributed datasets before training, but also for model monitoring at inference.
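As an illustration of the local differential privacy idea mentioned above, the following sketch clips a model update and adds Gaussian noise before the update leaves the client. The clipping norm and noise scale are hypothetical values; a real deployment would calibrate them with a proper privacy accounting method.

import numpy as np

def privatize_update(update, clip_norm=1.0, noise_std=0.1):
    # Bound the update's L2 norm, then add Gaussian noise (sketch of
    # local differential privacy; parameters are illustrative only).
    norm = np.linalg.norm(update)
    if norm > clip_norm:
        update = update * (clip_norm / norm)
    return update + np.random.normal(0.0, noise_std, size=update.shape)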

Beyond these algorithmic challenges, it is critical to build a secure architecture that provides end-to-end FL operations. One important challenge in building such an architecture is ease of deployment. The architecture must coordinate FL servers and clients for FL model building, training, and deployment, including continuous integration and continuous delivery (CI/CD) among clients, traceability, and authentication and access control for FL servers and clients. These features are similar to centralized ML operations (ML Ops), but are more challenging to implement because more parties are involved. The architecture also needs to be flexible enough to implement different FL topologies and synchronous or asynchronous aggregation.

Solution overview

We propose a cloud-native FL architecture on AWS, as shown in the following diagram. The architecture includes a central FL server and two FL clients. In practice, the number of FL clients in cross-silo settings can reach hundreds. The FL server must be on the AWS Cloud because it consists of a suite of microservices offered on the cloud. The FL clients can be on AWS or on the customer premises. Each FL client hosts its own local dataset and has its own IT and ML system for training ML models.

During FL model training, the FL server and a group of clients exchange ML models. That is, the clients download a global ML model from the server, perform local training, and upload local models to the server. The server downloads the local models and aggregates them into a new global model. This model exchange procedure constitutes a single FL training round. Training rounds repeat until the global model reaches a given accuracy or the number of rounds reaches a threshold.


Figure 1 – A cloud-native FL architecture for model training between an FL server and FL clients.

Prerequisites

To implement this solution, you need an AWS account to launch the services for a central FL server and the two clients. On-premises FL clients need to install the AWS Command Line Interface (AWS CLI), which allows access to the AWS services at the FL server, including Amazon Simple Queue Service (Amazon SQS), Amazon Simple Storage Service (Amazon S3), and Amazon DynamoDB.
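Once the AWS CLI is configured with credentials that are allowed to access the server's resources, an on-premises client can create the corresponding service clients with the AWS SDK for Python (Boto3). The Region below is a placeholder; use the Region where the FL server is deployed.

import boto3

# Credentials are picked up from the AWS CLI configuration (aws configure).
region = 'us-east-1'  # placeholder
sqs = boto3.client('sqs', region_name=region)
s3 = boto3.client('s3', region_name=region)
dynamodb = boto3.resource('dynamodb', region_name=region)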

Federated learning steps

In this section, we walk through the proposed architecture in Figure 1. At the FL server, an AWS Step Functions state machine runs the workflow shown in Figure 2, which executes Steps 0, 1, and 5 from Figure 1. The state machine initiates the AWS services at the server (Step 0) and iterates FL training rounds. For each training round, the state machine sends out an Amazon Simple Notification Service (Amazon SNS) notification to the topic global_model_ready, along with a task token (Step 1). The state machine then pauses and waits for a callback with the task token. SQS queues subscribe to the global_model_ready topic; each queue corresponds to an FL client and holds the notifications sent from the server to that client.

Figure 2 – The workflow at the Step Functions state machine.
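To illustrate Step 1, the following sketch approximates the notification that reaches the global_model_ready topic. In the deployed architecture, Step Functions publishes this message itself through its SNS service integration using the .waitForTaskToken callback pattern, so the topic ARN and payload fields below are assumptions for illustration only.

import json
import boto3

sns = boto3.client('sns')

# Placeholder values for illustration.
topic_arn = 'arn:aws:sns:us-east-1:123456789012:global_model_ready'
task_token = '<task-token-issued-by-step-functions>'

# Hypothetical payload: the global model's name plus the task token that
# the aggregation Lambda function later echoes back via send_task_success().
message = {'globalModelName': 'global_model_round_1.h5', 'taskToken': task_token}
sns.publish(TopicArn=topic_arn, Message=json.dumps(message))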

Each client continuously polls its assigned SQS queue. When a global_model_ready notification is received, the client downloads the global model from Amazon S3 (Step 2) and starts local training (Step 3). Local training generates a local model. The client then uploads the local model to Amazon S3 and writes the local model information, along with the received task token, to the DynamoDB table (Step 4).
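A minimal sketch of this polling and download logic with Boto3 follows. The queue URL, bucket name, and message fields are hypothetical; in the actual sample code, this logic is wrapped inside receiveNotificationsFromServer() and localTraining().

import json
import boto3

sqs = boto3.client('sqs', region_name='us-east-1')  # placeholder Region
s3 = boto3.client('s3')
queue_url = 'https://sqs.us-east-1.amazonaws.com/123456789012/client-1-queue'  # placeholder

# Long-poll the client's queue for a global_model_ready notification.
response = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=1, WaitTimeSeconds=20)
for message in response.get('Messages', []):
    body = json.loads(message['Body'])
    payload = json.loads(body['Message'])  # SNS wraps the payload in a 'Message' field
    # Download the announced global model from the S3 model registry (placeholder names).
    s3.download_file('fl-model-registry', payload['globalModelName'], 'global_model.h5')
    sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])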

We implement the FL model registry using Amazon S3 and DynamoDB. We use Amazon S3 to store the global and local models. We use a DynamoDB table to store local model information, because this information can differ between FL algorithms and therefore requires the flexible schema that DynamoDB supports.
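Because DynamoDB is schemaless beyond its key attributes, each FL algorithm can store whatever local model metadata it needs. The following sketch shows what a client's write to the table might look like; the table name and the attribute names other than taskName are hypothetical.

import boto3

table = boto3.resource('dynamodb').Table('fl-server-tasks')  # placeholder table name

# Hypothetical item: key attributes plus algorithm-specific model metadata.
table.put_item(
    Item={
        'taskName': 'fl-task-1',
        'roundId': '3',
        'clientId': 'client-1',
        'localModelName': 'client-1_round_3.h5',
        'numSamples': 5000,           # used by sample-weighted aggregation such as FedAvg
        'taskToken': '<task-token>',  # echoed back in the Step Functions callback
    }
)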

We also enable a DynamoDB stream to trigger a Lambda function, so that whenever a record is written into the DynamoDB table (when a new local model is received), a Lambda function is triggered to check whether all required local models have been collected (Step 5). If so, the Lambda function runs the aggregation function to aggregate the local models into a new global model. The resulting global model is written to Amazon S3. The function also sends a callback, along with the task token retrieved from the DynamoDB table, to the Step Functions state machine. The state machine then determines whether FL training should continue with a new round or stop, based on a condition such as the number of training rounds reaching a threshold.
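The completeness check can be as simple as comparing the set of clients that have written a local model for the current round against the set of participating clients. A sketch, assuming the transaction records described above:

def all_local_models_received(transactions, round_id, num_clients):
    # Return True when every participating client has written a local model
    # record for the given round (sketch; record fields are hypothetical).
    received = {t['clientId'] for t in transactions if t['roundId'] == round_id}
    return len(received) >= num_clients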

Each FL client uses the following sample code to interact with the FL server. If you want to customize the local training at your FL clients, you can modify the localTraining() function, as long as it returns local_model_name and local_model_info for uploading to the FL server. You can select any ML framework for training local models at FL clients, as long as all clients use the same ML framework.

# Step 2: receive notifications and model file name from its SQS queue
client.receiveNotificationsFromServer(sqs_region, client_queue_name)

# Step 3: download a global model and train locally
local_model_name, local_model_info = client.localTraining(global_model_name, s3_fl_model_registry)

# Step 4: upload the local model and local model info to the FL server
client.uploadToFLServer(s3_fl_model_registry, local_model_name, dynamodb_table_model_info, local_model_info)

The Lambda function that runs the aggregation function at the server has the following sample code. If you want to customize the aggregation algorithm, you need to modify the fedAvg() function and the output.

# Step 5: aggregate local models in the Lambda function
import os
import json

import boto3

def lambda_handler(event, context):
    # obtain task_name from the event triggered by the DynamoDB stream
    task_name = event['Records'][0]['dynamodb']['Keys']['taskName']['S']

    # retrieve transactions from the DynamoDB table
    transactions = readFromFLServerTaskTable(os.environ['TASKS_TABLE_NAME'], task_name)

    # read local model info from required clients
    # token is a callback token from the Step Functions state machine
    local_model_info, round_id, token = receiveUpdatedModelsFromClients(transactions, task_name)

    # the fedAvg function aggregates local models into a global model and stores the global model in S3
    global_model_name, avg_train_acc, avg_test_acc, avg_train_loss, avg_test_loss = fedAvg(local_model_info, round_id)

    # output sent to the Step Functions state machine
    output = {
        'taskName': task_name,
        'roundId': str(round_id),
        'trainAcc': str(avg_train_acc),
        'testAcc': str(avg_test_acc),
        'trainLoss': str(avg_train_loss),
        'testLoss': str(avg_test_loss),
        'weightsFile': str(global_model_name),
    }

    # send a callback to the Step Functions state machine to report that the
    # task identified by the token completed successfully
    step_client = boto3.client('stepfunctions')
    step_client.send_task_success(taskToken=token, output=json.dumps(output))
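If you customize the aggregation algorithm, a plausible starting point for the weighted-averaging core of fedAvg() is sketched below. It assumes each local model is a list of per-layer NumPy weight arrays (for example, as returned by a Keras model's get_weights()) and that each client reports its training sample count; loading models from Amazon S3 and computing the reported metrics are omitted.

import numpy as np

def fed_avg_weights(client_weight_lists, sample_counts):
    # Weight each client's contribution by its share of the total training
    # samples, then average the weights layer by layer (FedAvg sketch).
    fractions = [n / float(sum(sample_counts)) for n in sample_counts]
    num_layers = len(client_weight_lists[0])
    return [
        np.average(np.stack([weights[layer] for weights in client_weight_lists]),
                   axis=0, weights=fractions)
        for layer in range(num_layers)
    ]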

This architecture has two innovative designs. First, the FL server uses serverless services, such as Step Functions and Lambda. Therefore, no computing instance is kept running for the FL server, which minimizes the computing cost. Second, FL clients pull messages from their assigned SQS queues and upload or download models and information to or from services at the FL server. This design avoids the FL server directly accessing resources at the clients, which is critical for providing private and flexible IT and ML environments (on premises or on the AWS Cloud) to FL clients.

Advantages of being cloud-native

This architecture is cloud-native and provides end-to-end transparency by using AWS services with proven security and operational excellence. For example, you can have cross-account clients assume roles to access the resources at the FL server. For on-premises clients, the AWS CLI and AWS SDK for Python (Boto3) at the clients automatically provide secure network connections between the FL server and clients. For clients on the AWS Cloud, you can use AWS PrivateLink and AWS services with data encryption in transit and at rest for data protection. You can use Amazon Cognito and AWS Identity and Access Management (IAM) for the authentication and access control of FL servers and clients. For deploying the trained global model, you can use ML Ops capabilities in Amazon SageMaker.
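For example, a client in another AWS account could obtain temporary credentials for the server's resources by assuming an IAM role; the role ARN below is a placeholder.

import boto3

sts = boto3.client('sts')
credentials = sts.assume_role(
    RoleArn='arn:aws:iam::123456789012:role/fl-client-access',  # placeholder
    RoleSessionName='fl-client-session',
)['Credentials']

# Create a session with the temporary credentials to access the server's resources.
session = boto3.Session(
    aws_access_key_id=credentials['AccessKeyId'],
    aws_secret_access_key=credentials['SecretAccessKey'],
    aws_session_token=credentials['SessionToken'],
)
s3 = session.client('s3')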

The cloud-native architecture also enables integration with customized ML frameworks and federated learning algorithms and protocols. For example, you can select an ML framework for training local models at FL clients and customize different aggregation algorithms as scripts running in Lambda functions at the server. Also, you can modify the workflows in Step Functions to accommodate different communication protocols between the server and clients.

Another advantage of the cloud-native architecture is the ease of deployment by using IaC tools offered for the cloud. You can use the AWS Cloud Development Kit (AWS CDK) and AWS CloudFormation for one-click deployment.

Conclusion

New privacy laws continue to be implemented worldwide, and technology infrastructures are rapidly expanding across multiple regions and extending to network edges. Federated learning helps cloud customers use distributed datasets to train accurate ML models in a privacy-preserving manner. Federated learning also supports data localization and potentially saves costs, because it does not require large amounts of raw data to be moved or shared.

You can start experimenting and building cloud-native federated learning architectures for your use cases. You can customize the architecture to support various ML frameworks, such as TensorFlow or PyTorch. You can also customize it to support different FL algorithms, including asynchronous federated learning, aggregation algorithms, and differential privacy algorithms. You can enable this architecture with FL Ops functionalities using ML Ops capabilities in Amazon SageMaker.


About the Authors

Qiong (Jo) Zhang, PhD, is a Senior Partner Solutions Architect at AWS, specializing in AI/ML. Her current areas of interest include federated learning, distributed training, and generative AI. She holds 30+ patents and has co-authored 100+ journal/conference papers. She is also the recipient of the Best Paper Award at IEEE NetSoft 2016, IEEE ICC 2011, ONDM 2010, and IEEE GLOBECOM 2005.


Parker Newton is an applied scientist in AWS Cryptography. He received his Ph.D. in cryptography from U.C. Riverside, specializing in lattice-based cryptography and the complexity of computational learning problems. He is currently working at AWS in secure computation and privacy, designing cryptographic protocols to enable customers to securely run workloads in the cloud while preserving the privacy of their data.

Olivia Choudhury, PhD, is a Senior Partner Solutions Architect at AWS. She helps partners in the healthcare and life sciences domain design, develop, and scale state-of-the-art solutions using AWS. She has a background in genomics, healthcare analytics, federated learning, and privacy-preserving machine learning. Outside of work, she plays board games, paints landscapes, and collects manga.

Gang Fu is a Healthcare Solutions Architect at AWS. He holds a PhD in Pharmaceutical Science from the University of Mississippi and has over ten years of technology and biomedical research experience. He is passionate about technology and the impact it can make on healthcare.

Kris is a renowned leader in machine learning and generative AI, with a career spanning Goldman Sachs, consulting for major banks, and successful ventures like Foglight and SiteRock. He founded Indigo Capital Management and co-founded adaptiveARC, focusing on green energy tech. Kris also supports non-profits aiding assault victims and disadvantaged youth.

Bill Horne is a General Manager in AWS Cryptography. He leads the Cryptographic Computing Program, consisting of a team of applied scientists and engineers who are solving customer problems using emerging technologies like secure multiparty computation and homomorphic encryption. Prior to joining AWS in 2020 he was the VP and General Manager of Intertrust Secure Systems and was the Director of Security Research at Hewlett-Packard Enterprise. He is the author of 60 peer reviewed publications in the areas of security and machine learning, and holds 50 granted patents and 58 patents pending.