olympus.tasks.gan module

class olympus.tasks.gan.GAN(generator: torch.nn.modules.module.Module, discriminator: torch.nn.modules.module.Module, generator_optimizer: torch.optim.optimizer.Optimizer, discriminator_optimizer: torch.optim.optimizer.Optimizer, latent_vector_size: int = 10, criterion: torch.nn.modules.module.Module = CrossEntropyLoss())

Bases: olympus.tasks.task.Task

Attributes:
device
events
metrics

Methods:

criterion
eval_loss(batch) Compute the validation and test loss
fit(step, input, context) Execute a single batch
get_space() Return the missing hyperparameters that must be set using init()
init(**kwargs) Initialize the hyperparameters, if any
load_state_dict(state[, strict]) Try to load a previous unfinished state to resume from
state_dict([destination, prefix, keep_vars]) Save a state the task can return to if an error occurs
accuracy  
discriminate  
discriminate_probabilities  
finish  
generate  
report  
resumed  
set_device  
summary  
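The get_space()/init() contract in the table above can be illustrated with a small sketch: get_space() reports only the hyperparameters that still lack a value, and init() fills them in. Everything here is hypothetical: the class name, the hyperparameter names (`lr`, `momentum`) and the string priors are assumptions, and olympus's Task may represent its search space differently.

```python
# Hypothetical sketch of the get_space()/init() contract; not olympus's code.
# Class name, hyperparameter names and the string "priors" are assumptions.


class HyperParamTaskSketch:
    # Search space for every hyperparameter the task understands (assumed format).
    space = {"lr": "loguniform(1e-5, 1)", "momentum": "uniform(0, 1)"}

    def __init__(self, **known):
        # Hyperparameters already fixed at construction time.
        self.hyper_parameters = dict(known)

    def get_space(self):
        """Return only the hyperparameters that still need a value."""
        return {
            name: prior
            for name, prior in self.space.items()
            if name not in self.hyper_parameters
        }

    def init(self, **kwargs):
        """Set the missing hyperparameters."""
        self.hyper_parameters.update(kwargs)
        return self


task = HyperParamTaskSketch(lr=0.001)
print(task.get_space())  # {'momentum': 'uniform(0, 1)'}
task.init(momentum=0.9)
print(task.get_space())  # {}
```

A hyperparameter search tool would typically sample values for whatever get_space() returns and pass them to init().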
accuracy(batch, target)
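The accuracy entry carries no description. As a hedged illustration only, assuming the discriminator emits one logit per class (which the default CrossEntropyLoss criterion suggests) and that accuracy is the fraction of argmax predictions matching the targets, the computation looks like this. `accuracy_sketch` and the 0 = fake / 1 = real label convention are hypothetical.

```python
# Hypothetical sketch: accuracy as the fraction of argmax predictions that
# match the targets. Assumes one logit per class; not olympus's implementation.


def accuracy_sketch(logits, targets):
    # Predicted class = index of the largest logit in each row.
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


# Two classes; the label convention (0 = fake, 1 = real) is an assumption.
logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.0, 0.0]]
targets = [0, 1, 1, 1]
print(accuracy_sketch(logits, targets))  # 0.75
```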
criterion = CrossEntropyLoss()
discriminate(batch)
discriminate_probabilities(images)
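The difference between discriminate and discriminate_probabilities is not documented. One plausible reading, stated here as an assumption, is that discriminate returns raw class logits and discriminate_probabilities normalizes them with a softmax. The sketch below shows that normalization in plain Python; the function names are hypothetical.

```python
import math

# Hypothetical sketch: turning per-class discriminator logits into probabilities
# with a softmax. Whether olympus does exactly this is an assumption.


def softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def discriminate_probabilities_sketch(logits):
    # Apply the softmax independently to each sample's logits.
    return [softmax(row) for row in logits]


probs = discriminate_probabilities_sketch([[0.0, 0.0], [2.0, -1.0]])
print(probs[0])  # [0.5, 0.5]
```

Each output row sums to 1, so the values can be read as the discriminator's confidence in each class.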
finish()
fit(step, input, context)

Execute a single batch

Parameters:
step: int

current step in the training process

input

the batch to process

context: dict

Optional context

Notes

Wrap the body of this method in a BadResumeGuard so that users cannot resume a failed task that may be left in a bad state.
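The guard pattern in the note above can be sketched as a context manager that flags the task when an exception escapes the batch. `ResumeGuardSketch` and `GuardedTaskSketch` are hypothetical stand-ins; the interface of olympus's actual BadResumeGuard is not documented here and may differ.

```python
# Hypothetical sketch of the guard pattern; not olympus's BadResumeGuard.


class ResumeGuardSketch:
    def __init__(self, task):
        self.task = task

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            # The batch failed part-way through: flag the task so it is not resumed.
            self.task.bad_state = True
        return False  # do not swallow the exception; let it propagate


class GuardedTaskSketch:
    bad_state = False

    def fit(self, step, input, context=None):
        with ResumeGuardSketch(self):
            if input is None:
                raise ValueError("empty batch")
            # A real GAN task would update the discriminator and generator here.
            self.last_step = step


task = GuardedTaskSketch()
try:
    task.fit(0, None)
except ValueError:
    pass
print(task.bad_state)  # True
```

Returning False from `__exit__` keeps the original error visible to the caller while still recording that the task's state can no longer be trusted.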

To resume a task, create a new one with the same hyperparameters; it will automatically pick up from its last checkpoint.
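The resume behaviour described above can be sketched with state_dict()/load_state_dict(): a fresh task built with the same hyperparameters looks up the last checkpoint and restores from it. This is a hypothetical illustration; the in-memory `CHECKPOINTS` dict stands in for whatever storage olympus actually uses, and keying runs by their hyperparameters is an assumption.

```python
# Hypothetical sketch of checkpoint-based resuming; not olympus's implementation.
# CHECKPOINTS is an in-memory stand-in for real checkpoint storage.

CHECKPOINTS = {}


class ResumableTaskSketch:
    def __init__(self, **hyper_parameters):
        self.hyper_parameters = hyper_parameters
        self.step = 0
        # Pick up automatically if a checkpoint exists for these hyperparameters.
        key = self._key()
        if key in CHECKPOINTS:
            self.load_state_dict(CHECKPOINTS[key])

    def _key(self):
        # Identify a run by its hyperparameters (assumption).
        return tuple(sorted(self.hyper_parameters.items()))

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]

    def fit(self, step, input=None, context=None):
        self.step = step + 1
        # Checkpoint after every batch.
        CHECKPOINTS[self._key()] = self.state_dict()


first = ResumableTaskSketch(latent_vector_size=10, lr=0.0002)
for step in range(3):
    first.fit(step)

resumed = ResumableTaskSketch(latent_vector_size=10, lr=0.0002)
print(resumed.step)  # 3
```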

generate(latent_vector)
latent_vector_size = 10
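generate takes a latent vector whose size is given by latent_vector_size. The sketch below samples a batch of such vectors from a standard normal distribution. The Gaussian prior and the (batch_size, latent_vector_size) layout are assumptions, common in GANs but not confirmed by this page; `sample_latent_vectors` is hypothetical.

```python
import random

# Hypothetical sketch: sampling latent vectors for a generator.
# The standard-normal prior and the batch layout are assumptions.

LATENT_VECTOR_SIZE = 10  # mirrors the documented default


def sample_latent_vectors(batch_size, size=LATENT_VECTOR_SIZE, seed=None):
    rng = random.Random(seed)  # seeded for reproducibility
    return [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(batch_size)]


z = sample_latent_vectors(4, seed=0)
print(len(z), len(z[0]))  # 4 10
```

With PyTorch, the equivalent tensor would be `torch.randn(batch_size, latent_vector_size)`; whether generate expects exactly that shape is an assumption.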