44from torchvision import datasets
55from torchvision .transforms import ToTensor
66
7+ from ray import train
8+ from ray .train .torch import TorchTrainer
9+ from ray .train import ScalingConfig
10+
711
812def get_dataset ():
913 return datasets .FashionMNIST (
@@ -32,6 +36,7 @@ def forward(self, inputs):
3236 return logits
3337
3438
39+ # without distributed training, pure pytorch
3540def train_func ():
3641 num_epochs = 3
3742 batch_size = 64
@@ -54,4 +59,40 @@ def train_func():
5459 print (f"epoch: { epoch } , loss: { loss .item ()} " )
5560
5661
57- train_func ()
62+ # train_func()
63+
64+
65+ # distributed training
66+ def train_func_distributed ():
67+ num_epochs = 3
68+ batch_size = 64
69+
70+ dataset = get_dataset ()
71+ dataloader = DataLoader (dataset , batch_size = batch_size )
72+ dataloader = train .torch .prepare_data_loader (dataloader )
73+
74+ model = NeuralNetwork ()
75+ model = train .torch .prepare_model (model )
76+
77+ criterion = nn .CrossEntropyLoss ()
78+ optimizer = torch .optim .SGD (model .parameters (), lr = 0.01 )
79+
80+ for epoch in range (num_epochs ):
81+ for inputs , labels in dataloader :
82+ optimizer .zero_grad ()
83+ pred = model (inputs )
84+ loss = criterion (pred , labels )
85+ loss .backward ()
86+ optimizer .step ()
87+ print (f"epoch: { epoch } , loss: { loss .item ()} " )
88+
89+
# For GPU training, set `use_gpu` to True.
use_gpu = False

# Guard the entry point: Ray launches worker processes that may re-import
# this module, and an unguarded module-level `trainer.fit()` would try to
# launch the training job again on import. Behavior when running the file
# directly (`python this_file.py`) is unchanged.
if __name__ == "__main__":
    # Run `train_func_distributed` on 4 Ray Train workers.
    trainer = TorchTrainer(
        train_func_distributed,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=use_gpu),
    )

    results = trainer.fit()