olympus.tasks.gan module

class olympus.tasks.gan.GAN(generator: torch.nn.modules.module.Module, discriminator: torch.nn.modules.module.Module, generator_optimizer: torch.optim.optimizer.Optimizer, discriminator_optimizer: torch.optim.optimizer.Optimizer, latent_vector_size: int = 10, criterion: torch.nn.modules.module.Module = CrossEntropyLoss())

Bases: olympus.tasks.task.Task

Attributes:
device
metrics

Methods

criterion(*input, **kwargs)
eval_loss(self, batch): Compute the validation and test loss.
fit(self, step, input, context): Execute a single batch.
get_space(self, **fidelities): Return the missing hyperparameters that need to be set using init.
init(self, **kwargs): Initialize the hyperparameters, if any.
load_state_dict(self, state[, strict]): Try to load a previous unfinished state to resume from.
state_dict(self[, destination, prefix, …]): Save a state the task can go back to if an error occurs.
accuracy  
discriminate  
discriminate_probabilities  
finish  
generate  
report  
resumed  
set_device  
summary  
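get_space and init split hyperparameter handling into two steps: the task reports which hyperparameters are still unset, and the caller supplies them through init. The plain-Python sketch below illustrates that protocol with a hypothetical ToyTask class; the hyperparameter names (lr, beta1) and the dictionary format of the space are assumptions for illustration, not Olympus's actual definitions.

```python
# Hypothetical sketch of a "report missing hyperparameters, then init" protocol.
# ToyTask, its hyperparameter names and the space format are illustrative only.


class ToyTask:
    # Search space for every hyperparameter this toy task understands (assumed format).
    SPACE = {"lr": "loguniform(1e-5, 1)", "beta1": "uniform(0.5, 0.99)"}

    def __init__(self, **hyperparameters):
        self.hyperparameters = dict(hyperparameters)

    def get_space(self, **fidelities):
        # fidelities are accepted for signature parity but unused in this toy.
        # Only report hyperparameters that have not been set yet.
        return {
            name: prior
            for name, prior in self.SPACE.items()
            if name not in self.hyperparameters
        }

    def init(self, **kwargs):
        # Record the supplied values; anything still missing stays in get_space().
        self.hyperparameters.update(kwargs)


task = ToyTask(lr=0.001)
print(task.get_space())  # {'beta1': 'uniform(0.5, 0.99)'}
task.init(beta1=0.9)
print(task.get_space())  # {}
```

A search driver can call get_space, sample values for whatever is returned, and pass them to init before training starts.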
accuracy(self, batch, target)
criterion(*input, **kwargs) = CrossEntropyLoss()
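The default criterion is PyTorch's CrossEntropyLoss, which takes unnormalized scores (logits) and integer class targets and, with its default reduction='mean', averages -log softmax(logits)[target] over the batch. The plain-Python sketch below reproduces that computation. Labelling real images as class 1 and generated images as class 0 is an assumption about how a two-class discriminator could be trained, not necessarily what GAN.fit does.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over a batch, mirroring torch.nn.CrossEntropyLoss defaults.

    logits: list of per-sample score lists (unnormalized); targets: class indices.
    """
    total = 0.0
    for scores, target in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the max score first.
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        # -log softmax(scores)[target] == log_sum_exp - scores[target]
        total += log_sum_exp - scores[target]
    return total / len(targets)


# Assumed labelling for a two-class discriminator: 1 = real, 0 = generated.
REAL, FAKE = 1, 0
disc_logits = [[0.0, 0.0], [2.0, -1.0]]  # hypothetical discriminator outputs
print(cross_entropy(disc_logits, [REAL, FAKE]))
```

An undecided discriminator (equal scores) pays log(2) per sample, while a confident correct one pays close to zero.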
discriminate(self, batch)
discriminate_probabilities(self, images)
finish(self)
fit(self, step, input, context)

Execute a single batch.

Parameters:
step: int

Current step in the training process.

input

The batch to process.

context: dict

Optional context.

Notes

Wrap the body of this method inside a BadResumeGuard so that users cannot resume a failed task whose state may be corrupted.
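The note does not show how BadResumeGuard is implemented or used. As a rough illustration of the idea only, the sketch below defines a hypothetical ResumeGuard context manager (not Olympus's class) that flags a task as being in a bad state when an exception escapes the guarded block, and refuses to run the task again afterwards.

```python
# Hypothetical illustration of the guard idea; this is NOT Olympus's BadResumeGuard.


class ResumeGuard:
    """Mark `task` as unusable if the guarded block raises."""

    def __init__(self, task):
        self.task = task

    def __enter__(self):
        if getattr(self.task, "bad_state", False):
            raise RuntimeError("task is in a bad state; create a fresh task to resume")
        return self.task

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            # The batch failed part-way, so internal state may be inconsistent.
            self.task.bad_state = True
        return False  # never swallow the original exception


class ToyTask:
    def __init__(self):
        self.bad_state = False
        self.steps_done = 0

    def fit(self, step, input, context=None):
        with ResumeGuard(self):
            if input is None:
                raise ValueError("empty batch")
            self.steps_done += 1


task = ToyTask()
task.fit(0, [1, 2, 3])
try:
    task.fit(1, None)  # fails inside the guard
except ValueError:
    pass
print(task.steps_done, task.bad_state)  # 1 True
```

Any later call to task.fit raises RuntimeError, which pushes the user toward building a clean task instead of continuing from a possibly corrupted one.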

To resume a task, create a clean one with the same hyperparameters. It will automatically pick up where it left off at its last checkpoint.
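The resume workflow relies on state_dict and load_state_dict. The sketch below uses a hypothetical ToyTask in plain Python (not an Olympus class): it checkpoints after two steps, simulates a crash, then builds a clean task with the same hyperparameters, restores the checkpoint and continues from the saved step. Serializing the checkpoint to JSON and rejecting a checkpoint whose hyperparameters differ are assumptions made for the example.

```python
import json


class ToyTask:
    """Hypothetical task used to illustrate checkpoint/resume; not an Olympus class."""

    def __init__(self, lr):
        self.lr = lr  # hyperparameter: must match when resuming
        self.step = 0
        self.weight = 0.0

    def fit(self, step, input, context=None):
        # Toy "training": move the weight towards the batch mean.
        target = sum(input) / len(input)
        self.weight += self.lr * (target - self.weight)
        self.step = step + 1

    def state_dict(self):
        return {"step": self.step, "weight": self.weight, "lr": self.lr}

    def load_state_dict(self, state):
        # Assumed check: refuse checkpoints produced with other hyperparameters.
        if state["lr"] != self.lr:
            raise ValueError("checkpoint was produced with different hyperparameters")
        self.step = state["step"]
        self.weight = state["weight"]


batches = [[1.0, 3.0], [2.0, 2.0], [4.0, 0.0], [5.0, 1.0]]

task = ToyTask(lr=0.5)
for step in range(2):
    task.fit(step, batches[step])
checkpoint = json.dumps(task.state_dict())  # saved before a (simulated) crash

# Resume: build a clean task with the same hyperparameters and restore its state.
resumed = ToyTask(lr=0.5)
resumed.load_state_dict(json.loads(checkpoint))
for step in range(resumed.step, len(batches)):
    resumed.fit(step, batches[step])
print(resumed.step, resumed.weight)  # 4 2.375
```

Because the restored task continues from the saved step with the saved weight, it ends in the same state as an uninterrupted run.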

generate(self, latent_vector)
latent_vector_size = 10
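generate takes a latent vector, and latent_vector_size (default 10) sets that vector's length. A common choice in GANs is to draw latent vectors from a standard normal distribution; the plain-Python sketch below assumes that choice, using a hypothetical helper sample_latent_batch. The distribution Olympus actually uses is not stated here, and in the real task the batch would be a tensor rather than nested lists.

```python
import random

LATENT_VECTOR_SIZE = 10  # matches the documented default latent_vector_size


def sample_latent_batch(batch_size, size=LATENT_VECTOR_SIZE, seed=None):
    """Draw `batch_size` latent vectors with i.i.d. standard-normal entries (assumed prior)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(batch_size)]


z = sample_latent_batch(batch_size=4, seed=0)
print(len(z), len(z[0]))  # 4 10
```

Passing a fixed seed makes the sampled batch reproducible, which is convenient when comparing generator outputs across checkpoints.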