from dataclasses import dataclass, field
from olympus.observers.observer import Observer
from olympus.utils import show_dict, TimeThrottler
from olympus.utils.stat import StatStream
from olympus.utils.options import option
from datetime import datetime, timedelta
import warnings
def get_time_delta(start):
    """Return the number of seconds elapsed since ``start``.

    Parameters
    ----------
    start: datetime
        Reference instant. A naive value is compared against
        ``datetime.utcnow()``; a timezone-aware value is compared against
        the current time expressed in its own timezone.

    Returns
    -------
    float
        Elapsed wall-clock seconds (negative if ``start`` is in the future).
    """
    if start.utcoffset() is not None:
        # Naive and aware datetimes cannot be subtracted from one another,
        # so build "now" with the same awareness as the caller's value.
        now = datetime.now(start.tzinfo)
    else:
        now = datetime.utcnow()
    return (now - start).total_seconds()
@dataclass
class Speed(Observer):
    """Measure how long batches and epochs take.

    Step and epoch durations are accumulated into two ``StatStream``
    objects. :meth:`value` reports their averages and standard deviations,
    plus a throughput estimate derived from the size of the last batch.

    Notes
    -----
    ``drop_first_obs`` is forwarded to ``StatStream``; presumably it makes
    the stream ignore the first (warm-up) observations -- confirm against
    ``StatStream``.
    """

    # Size of the most recent batch; 0 when it could not be determined
    batch_size: int = 0
    # Event frequencies; not read inside this class (presumably consumed by
    # the observer dispatcher -- confirm against ``Observer``)
    frequency_new_epoch: int = 1
    frequency_end_epoch: int = 1
    frequency_new_batch: int = 1
    frequency_end_batch: int = 1
    # Duration statistics, in seconds
    step_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=5))
    epoch_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=1))
    # Start of the step / epoch currently being timed (naive UTC)
    step_start: datetime = field(default_factory=datetime.utcnow)
    epoch_start: datetime = field(default_factory=datetime.utcnow)
    # Last step / epoch index received from the events
    step: int = 0
    epoch: int = 0
    # Number of ``on_new_batch`` events received so far
    total_steps: int = 0
    priority: int = 10

    def guess_batch_size(self, input):
        """Best-effort guess of the number of samples in ``input``.

        Parameters
        ----------
        input:
            Either an object exposing ``shape`` or a list/tuple whose first
            element exposes ``shape``.

        Returns
        -------
        int
            The first dimension of the shape, or 0 when it cannot be
            determined (unknown type, empty sequence, missing shape...).
        """
        try:
            if isinstance(input, (list, tuple)):
                return input[0].shape[0]
            if hasattr(input, "shape"):
                return input.shape[0]
        except Exception:
            # Deliberately best-effort: a failed guess must never interrupt training
            pass
        # Always return an int so ``batch_size`` never silently becomes None
        return 0

    def on_new_epoch(self, task, epoch, context):
        """Record the epoch index and start timing the epoch."""
        self.epoch = epoch
        self.epoch_start = datetime.utcnow()

    def on_end_epoch(self, task, epoch, context=None):
        """Add the duration of the epoch that just finished to ``epoch_time``."""
        self.epoch_time += get_time_delta(self.epoch_start)

    def on_new_batch(self, task, step, input=None, context=None):
        """Record the step index, count the step and start timing it."""
        self.step = step
        self.total_steps += 1
        self.step_start = datetime.utcnow()

    def on_end_batch(self, task, step, input=None, context=None):
        """Add the duration of the step to ``step_time`` and refresh ``batch_size``."""
        self.step_time += get_time_delta(self.step_start)
        self.batch_size = self.guess_batch_size(input)

    def value(self):
        """Return the timing metrics gathered so far.

        Returns
        -------
        dict
            May contain ``step_time``, ``batch_speed``, ``batch_size``,
            ``step_time_sd``, ``epoch_time`` and ``epoch_time_sd``; each key
            is present only once enough observations were collected.
        """
        result = {}

        if self.step_time.count > 0:
            avg_step = self.step_time.avg
            result["step_time"] = avg_step

            if self.batch_size and self.batch_size > 0:
                # A zero average (very coarse clock) would make the division fail
                if avg_step > 0:
                    result["batch_speed"] = self.batch_size / avg_step
                result["batch_size"] = self.batch_size

        if self.step_time.count > 2:
            result["step_time_sd"] = self.step_time.sd

        if self.epoch_time.count > 0:
            result["epoch_time"] = self.epoch_time.avg

        if self.epoch_time.count > 2:
            result["epoch_time_sd"] = self.epoch_time.sd

        return result

    def state_dict(self):
        """Return the state needed to resume the measurements."""
        return {
            "speed_step_time": self.step_time.state_dict(),
            # Must be the epoch stream: saving the step stream here would
            # replace the epoch statistics with step statistics on resume
            "speed_epoch_time": self.epoch_time.state_dict(),
            "total_steps": self.total_steps,
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`.

        Both timers are restarted from now, so time spent while the run was
        stopped is not counted in the next step or epoch.
        """
        self.step_time.from_dict(state_dict["speed_step_time"])
        self.epoch_time.from_dict(state_dict["speed_epoch_time"])
        self.step_start = datetime.utcnow()
        self.epoch_start = datetime.utcnow()
        self.total_steps = state_dict["total_steps"]
        self.epoch = state_dict["epoch"]
class GuessMaxStep:
    """Infer the number of steps per epoch from the observed step indices.

    The largest step index seen so far (plus one) is tracked. The guess is
    considered settled as soon as an update does not raise that maximum,
    which is what happens when the step counter starts over.

    Parameters
    ----------
    max_steps: Optional[int]
        Known number of steps; when given the value is trusted right away.
    """

    def __init__(self, max_steps=None):
        known = max_steps is not None
        self.guessed = known
        self.current_max = max_steps if known else float("-inf")

    def update(self, new_step):
        """Feed a new step index; ``None`` is ignored."""
        if new_step is None:
            return

        raises_maximum = new_step > self.current_max - 1
        if not raises_maximum:
            # The index did not grow: the maximum has been reached
            self.guessed = True
            return

        self.current_max = new_step + 1

    def max_step(self):
        """Return the number of steps, or 0 while it is still unknown."""
        return self.current_max if self.guessed else 0

    def state_dict(self):
        """Return the state needed to resume the guess."""
        return {"guessed": self.guessed, "current_max": self.current_max}

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`."""
        self.current_max = state_dict["current_max"]
        self.guessed = state_dict["guessed"]
def fill(msg, size=40):
    """Right-pad ``msg`` with spaces so that it is at least ``size`` wide.

    Parameters
    ----------
    msg: str
        Text to pad.
    size: int
        Minimum width of the result.

    Returns
    -------
    str
        ``msg`` followed by the missing spaces; messages that are already
        ``size`` characters or longer are returned unchanged.
    """
    # ``max`` (not ``min``): the padding is the missing width, never negative
    missing = max(0, size - len(msg))
    return f"{msg}{' ' * missing}"
def show_progress(speed: Speed):
    """Format the step count, the elapsed minutes and the average step time.

    Parameters
    ----------
    speed: Speed
        Observer providing ``total_steps`` and the ``step_time`` statistics.

    Returns
    -------
    str
        For example ``"  12 Elapsed time 0.12 min (1.00 s/step)"``.
    """
    timings = speed.step_time
    elapsed_minutes = timings.total / 60
    seconds_per_step = timings.avg
    return (
        f"{speed.total_steps:4d} Elapsed time {elapsed_minutes:.2f} min "
        f"({seconds_per_step:.2f} s/step)"
    )
@dataclass
class DefaultProgress:
    """Progress printer used when neither the epoch nor the step count is known.

    It only reports what the speed observer measured and keeps no state of
    its own.
    """

    # Speed observer the timings are read from
    speed: Speed

    def show_progress(self):
        """Return the elapsed-time message built from ``speed``."""
        message = show_progress(self.speed)
        return message

    def state_dict(self):
        """Return an empty state; there is nothing to save."""
        return {}

    def load_state_dict(self, state_dict):
        """Accept any state and ignore it; there is nothing to restore."""
        return None
@dataclass
class EpochProgress:
    """Progress printer used when the number of epochs is known.

    While the number of steps per epoch is still being guessed the plain
    elapsed-time message is returned; once it is known the message shows
    the completion percentage and an estimate of the remaining time.
    """

    # Speed observer the timings and counters are read from
    speed: Speed
    # Total number of epochs
    epochs: int
    # Number of steps per epoch, known up front or guessed on the fly
    steps: GuessMaxStep = field(default_factory=GuessMaxStep)

    def show_progress(self):
        """Return the progress message for the current step.

        Returns
        -------
        str
            ``[ 25.00 %] Epoch [  1/  4][  12/  12] Remaining: 0:00:36`` style
            message, or the elapsed-time message when the total number of
            steps is not available.
        """
        if not self.steps.guessed:
            self.steps.update(self.speed.step)
            return show_progress(self.speed)

        steps_per_epoch = self.steps.max_step()
        total_steps = self.epochs * steps_per_epoch
        if total_steps <= 0:
            # No usable total (0 epochs or 0 steps): a percentage cannot be
            # computed, so report elapsed time instead of dividing by zero
            return show_progress(self.speed)

        epoch = self.speed.epoch
        step = self.speed.step
        done_steps = self.speed.total_steps

        # More steps than planned can have been done (e.g. wrong guess);
        # never report a negative remaining time
        remaining_steps = max(0, total_steps - done_steps)

        step_time = self.speed.step_time
        if step_time.count > 0:
            # The average is only read once at least one step was timed
            remaining_time = timedelta(seconds=remaining_steps * step_time.avg)
        else:
            remaining_time = "N/A"

        completion = done_steps * 100 / total_steps

        return (
            f"[{completion:6.2f} %] Epoch [{epoch:3d}/{self.epochs:3d}]"
            f"[{step + 1:4d}/{steps_per_epoch:4d}] "
            f"Remaining: {remaining_time}"
        )

    def state_dict(self):
        """Return the state of the step guess."""
        return dict(steps=self.steps.state_dict())

    def load_state_dict(self, state_dict):
        """Restore the step guess saved by :meth:`state_dict`."""
        self.steps.load_state_dict(state_dict["steps"])
@dataclass
class StepProgress:
    """Progress printer used when only the total number of steps is known."""

    # Speed observer the timings and counters are read from
    speed: Speed
    # Total number of steps expected
    steps: int

    def show_progress(self):
        """Return ``[done/total] Remaining: <estimate>`` padded by ``fill``.

        The estimate is ``"N/A"`` until at least one step has been timed.
        """
        total = self.speed.total_steps
        # Going past the planned number of steps must not produce a
        # negative duration in the message
        remaining_steps = max(0, self.steps - total)

        if self.speed.step_time.count > 0:
            remaining_time = timedelta(
                seconds=remaining_steps * self.speed.step_time.avg
            )
        else:
            remaining_time = "N/A"

        return fill(f"[{total:4d}/{self.steps:4d}] Remaining: {remaining_time}")

    def state_dict(self):
        """Return an empty state; there is nothing to save."""
        return dict()

    def load_state_dict(self, state_dict):
        """Accept any state and ignore it; there is nothing to restore."""
        pass
class ProgressView(Observer):
    """Print progress regularly

    Parameters
    ----------
    speed: Speed
        speed observer used to gather information about timings
        It is used to compute an estimated end time

    max_epochs: Optional[int]
        The total number of epochs

    max_steps: Optional[int]
        The total number of steps in a single epoch

    Notes
    -----
    If neither max epochs nor max steps is specified it outputs
    ``12 Elapsed time 0.12 min (1.00 s/step)``

    If both max epochs and max steps are specified, it outputs
    ``[ 25.00 %] Epoch [  1/  4][  12/  12] Remaining: 0:00:36.042655``

    If only max epochs is specified we will try to guess the max steps
    during the first epoch.
    """

    def __init__(self, speed: Speed, max_epochs=None, max_steps=None):
        # Throttling period handed to ``TimeThrottler`` (presumably seconds -- confirm)
        self.print_throttle = option("progress.print.throttle", 30, type=int)
        self.print_fun = print
        self.throttled_print = TimeThrottler(self.print_fun, every=self.print_throttle)

        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.speed = speed

        # A printer must exist before the selection: the selection keeps the
        # current printer when neither max_epochs nor max_steps is known
        self.progress_printer = DefaultProgress(self.speed)
        self.progress_printer = self.select_progress_printer(max_epochs, max_steps)

        self.frequency_new_epoch: int = 1
        self.frequency_end_epoch: int = option("progress.frequency.epoch", 1, type=int)
        self.frequency_end_batch: int = option("progress.frequency.batch", 1, type=int)
        # "epoch" or "batch": the event after which the task metrics are shown
        self.show_metrics: str = option("progress.show.metrics", "epoch")
        self.frequency_trial: int = 0
        # Negative means no worker prefix is printed
        self.worker_id: int = option("worker.id", -1, type=int)
        self.first_epoch = None

    def set_max_epochs(self, epochs):
        """Set the total number of epochs and pick the matching printer."""
        self.max_epochs = epochs
        self.select_progress_printer()

    def set_max_steps(self, steps):
        """Set the number of steps and pick the matching printer."""
        self.max_steps = steps
        self.select_progress_printer()

    def select_progress_printer(self, max_epochs=None, max_steps=None):
        """Choose the printer matching the information that is available.

        Parameters
        ----------
        max_epochs: Optional[int]
            Falls back to ``self.max_epochs`` when ``None``.

        max_steps: Optional[int]
            Falls back to ``self.max_steps`` when ``None``.

        Returns
        -------
        The selected printer, which is also stored in ``progress_printer``.
        The current printer is kept when both values are unknown.
        """
        if max_epochs is None:
            max_epochs = self.max_epochs

        if max_steps is None:
            max_steps = self.max_steps

        if max_epochs is not None:
            # ``GuessMaxStep(None)`` starts in guessing mode, which covers
            # the case where only the number of epochs is known
            self.progress_printer = EpochProgress(
                self.speed, max_epochs, GuessMaxStep(max_steps)
            )
        elif max_steps is not None:
            self.progress_printer = StepProgress(self.speed, max_steps)

        return self.progress_printer

    def reset_throttle(self):
        """Replace the throttled printer with a fresh one."""
        self.throttled_print = TimeThrottler(self.print_fun, every=self.print_throttle)

    def show_progress(self, start="\r", end="\n"):
        """Print the progress line through the throttled printer.

        Parameters
        ----------
        start: str
            Text emitted before the line.

        end: str
            Text emitted after the line.
        """
        worker = ""
        if self.worker_id >= 0:
            worker = f"[W: {self.worker_id:2d}] "

        progress = self.progress_printer.show_progress()
        # Pad the visible text only: padding placed after ``end`` would land
        # at the beginning of the next line
        line = fill(f"{worker}{progress}")
        self.throttled_print(f"{start}{line}{end}", end="")

    def on_start_train(self, task, step=None):
        """Announce the start of training and show the initial metrics."""
        self.print_fun("Starting")
        if task:
            show_dict(task.metrics.value(), print_fun=self.print_fun)

    def on_resume_train(self, task, epoch):
        """Announce that training resumes and show the current metrics."""
        self.print_fun("Resuming at epoch", epoch)
        if task:
            show_dict(task.metrics.value(), print_fun=self.print_fun)

    def on_end_train(self, task, step=None):
        """Announce the end of training and show the final metrics."""
        self.print_fun("Completed training")
        if task:
            # Same print function as the other events so a replaced
            # ``print_fun`` receives the final metrics too
            show_dict(task.metrics.value(), print_fun=self.print_fun)

    def on_new_epoch(self, task, epoch, context):
        """Remember the first epoch seen and warn if it is epoch 0."""
        if self.first_epoch is None:
            self.first_epoch = epoch

            if epoch == 0:
                warnings.warn(
                    "First epoch should 1; epoch 0 is used for the untrained model"
                )

    def on_end_epoch(self, task, epoch, context):
        """Print the progress line, then the metrics if shown per epoch."""
        # A fresh throttler is installed before printing the end-of-epoch line
        self.reset_throttle()
        self.show_progress("", "\n")

        if task is not None and self.show_metrics == "epoch":
            show_dict(task.metrics.value(), print_fun=self.print_fun)

    def on_end_batch(self, task, step, input=None, context=None):
        """Print the progress line, then the metrics if shown per batch."""
        self.show_progress()

        if task is not None and self.show_metrics == "batch":
            show_dict(task.metrics.value(), print_fun=self.print_fun)

    def value(self):
        """Return no metrics; this observer only prints."""
        return {}

    def state_dict(self):
        """Return the printer state and the known epoch/step limits."""
        return dict(
            progress_printer=self.progress_printer.state_dict(),
            max_steps=self.max_steps,
            max_epochs=self.max_epochs,
        )

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`."""
        self.max_steps = state_dict["max_steps"]
        self.max_epochs = state_dict["max_epochs"]
        # Select the printer first: the selection may build a new printer,
        # which would discard a state loaded beforehand
        self.select_progress_printer()
        self.progress_printer.load_state_dict(state_dict["progress_printer"])
@dataclass
class SampleCount(Observer):
    """Count the samples processed so far, one batch at a time."""

    # Running total of samples
    sample_count: int = 0
    frequency_end_batch: int = 1

    def state_dict(self):
        """Return the running total so it can be checkpointed."""
        return {"sample_count": self.sample_count}

    def load_state_dict(self, state_dict):
        """Restore the running total saved by :meth:`state_dict`."""
        self.sample_count = state_dict["sample_count"]

    def on_end_batch(self, task, step, input=None, context=None):
        """Add the size of the batch that just ended to the total.

        The size is, in order of precedence, ``input.shape[0]``, the length
        of ``input[0]`` for indexable inputs, 1 when there is no input, and
        ``input.size(0)`` otherwise.
        """
        if hasattr(input, "shape"):
            self.sample_count += input.shape[0]
            return

        if hasattr(input, "__getitem__"):
            self.sample_count += len(input[0])
            return

        self.sample_count += 1 if input is None else input.size(0)

    def value(self):
        """Return the running total as a metric dictionary."""
        return {"sample_count": self.sample_count}
@dataclass
class ElapsedRealTime(Observer):
    """Track the wall-clock time elapsed between creation and the last event.

    ``end_time`` is refreshed at the end of every batch and at the end of
    training; the elapsed time is the difference with ``start_time``.
    """

    # Both timestamps are naive UTC datetimes
    start_time: datetime = field(default_factory=datetime.utcnow)
    end_time: datetime = field(default_factory=datetime.utcnow)

    def state_dict(self):
        """Return the elapsed time so it can be checkpointed."""
        return self.value()

    def load_state_dict(self, state_dict):
        """Resume with ``state_dict["elapsed_time"]`` seconds already elapsed.

        The clock is re-anchored on the current time so the period between
        the creation of this observer and the restore is not counted.
        """
        self.end_time = datetime.utcnow()
        self.start_time = self.end_time - timedelta(seconds=state_dict["elapsed_time"])

    def on_end_batch(self, task, step, input=None, context=None):
        """Refresh ``end_time``; arguments follow the ``(task, step)`` order
        used by the other observers and are not read."""
        self.end_time = datetime.utcnow()

    def on_end_train(self, task, step=None):
        """Refresh ``end_time`` one last time when training completes."""
        self.end_time = datetime.utcnow()

    @property
    def elapsed_time(self):
        """Seconds between ``start_time`` and ``end_time``."""
        return (self.end_time - self.start_time).total_seconds()

    def value(self):
        """Return the elapsed time as a metric dictionary."""
        return {"elapsed_time": self.elapsed_time}