import torch
from typing import TypeVar
class Tensor(metaclass=TensorMeta):
    """Shape / device / element-type descriptor used for tensor annotations.

    Subscripting an *instance* (``t[...]``) updates its metadata in place and
    returns the same instance. Subscripting the *class* (``Tensor[...]``) is
    presumably handled by ``TensorMeta``, which is defined elsewhere.
    """

    def __init__(self, shape, device, type):
        """Store the descriptor's metadata.

        Args:
            shape: Dimension spec, stored exactly as given.
            device: Associated device (presumably a ``torch.device``; not
                validated).
            type: Element type (presumably a ``torch.dtype``). The name
                shadows the builtin but is kept for keyword callers. It is
                also readable/writable as ``dtype``.
        """
        self.shape = shape
        self.device = device
        self.type = type

    @property
    def dtype(self):
        """Alias for ``type`` so both spellings refer to the same value."""
        return self.type

    @dtype.setter
    def dtype(self, value):
        self.type = value

    def __getitem__(self, *args):
        """Update shape, device and/or dtype from a subscript; return ``self``.

        ``t[C, H, W]`` reaches this method as ONE tuple key, so a lone tuple
        argument is unpacked into individual parameters. Each parameter is
        then interpreted as:

        * ``int`` or ``TypeVar`` -- one dimension, appended to the shape;
        * ``tuple`` -- several dimensions (e.g. ``CHW``), appended in order;
        * ``torch.device`` -- replaces ``device``;
        * ``torch.dtype`` -- replaces ``type`` / ``dtype``.

        The shape is replaced only when at least one dimension parameter was
        given, so ``t[torch.float32]`` keeps the existing shape. Unrecognized
        parameters are reported and ignored.
        """
        # Python packs ``t[a, b]`` into a single tuple key; unpack it so each
        # component is classified on its own.
        params = args
        if len(args) == 1 and isinstance(args[0], tuple):
            params = args[0]

        dims = []
        shape_given = False
        for arg in params:
            if isinstance(arg, (int, TypeVar)):
                dims.append(arg)
                shape_given = True
            elif isinstance(arg, tuple):
                # A grouped shape such as CHW contributes all of its dims.
                dims.extend(arg)
                shape_given = True
            elif isinstance(arg, torch.device):
                self.device = arg
            elif isinstance(arg, torch.dtype):
                self.type = arg
            else:
                print(f'{arg} is not recognized as a Tensor parameter')

        if shape_given:
            self.shape = tuple(dims)
        return self
# Symbolic dimension variables for shape annotations. The names follow the
# usual NCHW layout convention (presumably channels, height, width, and a
# leading batch dimension N) -- confirm against callers.
C = TypeVar('C')
H = TypeVar('H')
W = TypeVar('W')
CHW = (C, H, W)

# Same per-item dimensions with one extra leading dimension.
N = TypeVar('N')
NCHW = (N, *CHW)
# Annotation alias for images with symbolic (C, H, W) dimensions.
# NOTE(review): class-level subscription is resolved by TensorMeta, which is
# defined elsewhere -- confirm what object this expression produces.
Image = Tensor[CHW]
class Bound1D:
    """A one-dimensional bound: a ``min`` and a ``max`` value, stored as given.

    No validation is performed (for example, ``min <= max`` is not checked).
    """

    def __init__(self, min, max):
        # The parameter names shadow the builtins ``min``/``max``. They are
        # kept because callers may pass them by keyword.
        self.min = min
        self.max = max

    def __repr__(self):
        return f'{type(self).__name__}(min={self.min!r}, max={self.max!r})'
class VariableShape:
    """A named collection of (presumably) shape specs, given as keywords.

    ``VariableShape(x=a, y=b)`` stores ``{'x': a, 'y': b}`` in ``shapes``.
    Values are kept exactly as passed.
    """

    def __init__(self, **shapes):
        self.shapes = shapes

    def __repr__(self):
        items = ', '.join(
            f'{name}={value!r}' for name, value in self.shapes.items()
        )
        return f'{type(self).__name__}({items})'
class DictionaryShape:
    """A shape described by a sequence of keys.

    ``DictionaryShape('a', 'b')`` stores the positional arguments as the
    tuple ``('a', 'b')`` in ``keys``.
    """

    def __init__(self, *keys):
        self.keys = keys

    def __repr__(self):
        args = ', '.join(map(repr, self.keys))
        return f'{type(self).__name__}({args})'