Skip to content

Commit a94c06d

Browse files
authored
Merge pull request #5 from YiwenAI/main
update ray train
2 parents c75e3d0 + c192d15 commit a94c06d

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

basic/test_ray_train.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from torchvision import datasets
55
from torchvision.transforms import ToTensor
66

7+
from ray import train
8+
from ray.train.torch import TorchTrainer
9+
from ray.train import ScalingConfig
10+
711

812
def get_dataset():
913
return datasets.FashionMNIST(
@@ -32,6 +36,7 @@ def forward(self, inputs):
3236
return logits
3337

3438

39+
# without distributed training, pure pytorch
3540
def train_func():
3641
num_epochs = 3
3742
batch_size = 64
@@ -54,4 +59,40 @@ def train_func():
5459
print(f"epoch: {epoch}, loss: {loss.item()}")
5560

5661

57-
train_func()
62+
# train_func()
63+
64+
65+
# distributed training
66+
def train_func_distributed():
    """Per-worker training loop executed by Ray's TorchTrainer.

    Mirrors the plain PyTorch loop above, but routes the data loader and
    the model through Ray Train's `prepare_*` helpers so each worker gets
    a sharded loader and a device-placed, DDP-wrapped model.
    """
    num_epochs = 3
    batch_size = 64

    # Shard the loader across workers and move batches to the right device.
    loader = train.torch.prepare_data_loader(
        DataLoader(get_dataset(), batch_size=batch_size)
    )

    # Wrap the network for distributed training (device placement + DDP).
    model = train.torch.prepare_model(NeuralNetwork())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(num_epochs):
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
        # NOTE(review): the scraped diff loses indentation; the print is
        # placed at the conventional once-per-epoch position — confirm
        # against the upstream file.
        print(f"epoch: {epoch}, loss: {loss.item()}")
88+
89+
90+
# Launch the distributed run. For GPU training, set `use_gpu` to True.
use_gpu = False

# Four workers, each executing `train_func_distributed` once; Ray sets up
# the torch process group and collects the results.
scaling = ScalingConfig(num_workers=4, use_gpu=use_gpu)
trainer = TorchTrainer(train_func_distributed, scaling_config=scaling)

results = trainer.fit()

0 commit comments

Comments
 (0)