from dataclasses import dataclass, field
from olympus.observers.observer import Observer
from olympus.utils import show_dict
from olympus.utils.stat import StatStream
from olympus.utils.options import option
from datetime import datetime, timedelta
from typing import Optional
[docs]def get_time_delta(start):
return (datetime.utcnow() - start).total_seconds()
[docs]@dataclass
class Speed(Observer):
batch_size: int = 0
frequency_new_epoch: int = 1
frequency_end_epoch: int = 1
frequency_new_batch: int = 1
frequency_end_batch: int = 1
step_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=5))
epoch_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=1))
step_start: datetime = field(default_factory=datetime.utcnow)
epoch_start: datetime = field(default_factory=datetime.utcnow)
[docs] def guess_batch_size(self, input):
try:
return input[0].shape[0]
except Exception:
return 0
[docs] def on_new_epoch(self, epoch, task, context):
self.epoch_start = datetime.utcnow()
[docs] def on_end_epoch(self, task, epoch, context=None):
self.epoch_time += get_time_delta(self.epoch_start)
[docs] def on_new_batch(self, task, step, input=None, context=None):
self.step_start = datetime.utcnow()
[docs] def on_end_batch(self, task, step, input=None, context=None):
self.step_time += get_time_delta(self.step_start)
self.batch_size = self.guess_batch_size(input)
[docs] def value(self):
result = {}
if self.step_time.count > 0:
result['step_time'] = self.step_time.avg
if self.batch_size > 0:
result['batch_speed'] = self.batch_size / self.step_time.avg
if self.step_time.count > 2:
result['step_time_sd'] = self.step_time.sd
if self.epoch_time.count > 0:
result['epoch_time'] = self.epoch_time.avg
if self.epoch_time.count > 2:
result['epoch_time_sd'] = self.epoch_time.sd
return result
[docs] def state_dict(self):
return {
'speed_step_time': self.step_time.state_dict(),
'speed_epoch_time': self.step_time.state_dict(),
}
[docs] def load_state_dict(self, state_dict):
self.step_time.from_dict(state_dict['speed_step_time'])
self.epoch_time.from_dict(state_dict['speed_epoch_time'])
self.step_start = datetime.utcnow()
self.epoch_start = datetime.utcnow()
[docs]@dataclass
class ProgressView(Observer):
speed_observer: Optional[Speed] = None
print_fun = print
max_epoch: int = 0
max_step: int = 0
step_length: int = 0
epoch: int = 0
step: int = 0
multiplier: int = 0
frequency_end_epoch: int = field(
default_factory=lambda: option('progress.frequency.epoch', 1, type=int))
frequency_end_batch: int = field(
default_factory=lambda: option('progress.frequency.batch', 1, type=int))
show_metrics: str = field(
default_factory=lambda: option('progress.show.metrics', 'epoch'))
frequency_trial: int = 0
orion_handle = None
worker_id: int = option('worker.id', -1, type=int)
[docs] def show_progress(self, epoch, step=None):
if step is None:
step = ' ' * self.step_length
else:
step = f'Step [{step:3d}/{self.max_step:3d}]'
self.step_length = len(step)
hpo = ''
if self.orion_handle is not None:
hpo_completion = self.overall_progress()
hpo = f'HPO [{hpo_completion:6.2f}%] '
worker = ''
if self.worker_id >= 0:
worker = f'[W: {self.worker_id:2d}] '
eta = ''
if self.speed_observer:
eta = self.eta(self.speed_observer, epoch)
self.print_fun(
f'\r{worker}{hpo}Epoch [{epoch:3d}/{self.max_epoch:3d}] {step} {eta}', end='')
[docs] def overall_progress(self):
"""Return the overall HPO progress in % completion"""
return len(self.orion_handle.fetch_trials_by_status('completed')) * 100 / self.number_of_trials()
[docs] def number_of_trials(self):
# FIXME: Get max trials for the algo itself
return self.orion_handle.max_trials
[docs] def estimate_time_trial_finish(self, obs, epoch):
"""Estimate when a trial will finish"""
if obs.step_time.count == 0:
return None
total_steps = self.max_step * self.max_epoch
spent_steps = self.max_step * epoch + self.step
remaining_steps = total_steps - spent_steps
avg = obs.step_time.avg
# if we spent enough epochs estimate using both duration
if obs.epoch_time.count > 0:
avg = (avg + obs.epoch_time.avg / float(self.max_step)) / 2
step_estimate = avg * remaining_steps
return step_estimate
[docs] def eta(self, obs, epoch):
step_estimate = self.estimate_time_trial_finish(obs, epoch)
if step_estimate:
return f'ETA: {step_estimate / 60:9.4f} min'
return ''
[docs] def on_end_epoch(self, task, epoch, context):
self.epoch = epoch
self.max_epoch = max(self.epoch, self.max_epoch)
self.print_fun()
self.show_progress(epoch)
self.print_fun()
if self.show_metrics == 'epoch':
show_dict(task.metrics.value())
[docs] def on_end_batch(self, task, step, input=None, context=None):
self.step = step
self.max_step = max(step, self.max_step)
self.show_progress(self.epoch, step)
if self.show_metrics == 'batch':
show_dict(task.metrics.value())
[docs] def init_speed_observer(self, task):
if not self.speed_observer and task:
self.speed_observer = task.metrics.get('Speed', None)
[docs] def value(self):
return {}
[docs] def state_dict(self):
return dict(
max_epoch=self.max_epoch,
max_step=self.max_step
)
[docs] def load_state_dict(self, state_dict):
self.max_epoch = state_dict['max_epoch']
self.max_step = state_dict['max_step']
[docs]@dataclass
class SampleCount(Observer):
sample_count: int = 0
epoch: int = 0
frequency_end_batch: int = 1
frequency_end_epoch: int = 1
[docs] def state_dict(self):
return dict(epoch=self.epoch, sample_count=self.sample_count)
[docs] def load_state_dict(self, state_dict):
self.sample_count = state_dict['sample_count']
self.epoch = state_dict['epoch']
[docs] def on_end_epoch(self, task, epoch, context):
self.epoch = epoch
[docs] def on_end_batch(self, task, step, input=None, context=None):
if hasattr(input, '__getitem__'):
batch_size = len(input[0])
else:
batch_size = input.size(0)
self.sample_count += batch_size
[docs] def value(self):
return {
'sample_count': self.sample_count,
'epoch': self.epoch
}
[docs]@dataclass
class ElapsedRealTime(Observer):
start_time: datetime = field(default_factory=datetime.utcnow)
end_time: datetime = field(default_factory=datetime.utcnow)
frequency_end_batch: int = 1
frequency_end_train: int = 1
[docs] def state_dict(self):
return self.value()
[docs] def load_state_dict(self, state_dict):
self.start_time = self.end_time - timedelta(seconds=state_dict['elapsed_time'])
[docs] def on_end_batch(self, step, task, input=None, context=None):
self.end_time = datetime.utcnow()
[docs] def on_end_train(self, task, step=None):
self.end_time = datetime.utcnow()
@property
def elapsed_time(self):
return (self.end_time - self.start_time).total_seconds()
[docs] def value(self):
return {
'elapsed_time': self.elapsed_time
}