Say I already have a method with type annotations:
```python
class Shape:
    def area(self) -> float:
        raise NotImplementedError
```
I then subclass it multiple times:
```python
import math

class Circle(Shape):
    def area(self) -> float:
        return math.pi * self.radius ** 2

class Rectangle(Shape):
    def area(self) -> float:
        return self.height * self.width
```
As you can see, I'm duplicating the `-> float` quite a lot. Say I have 10 different shapes, each with multiple methods like this, some of which take parameters too. Is there a way to just "copy" the annotation from the parent class, similar to what `functools.wraps()` does with docstrings?
This might work, though I'm sure it misses some edge cases, such as additional arguments:
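```python
from functools import partial, update_wrapper

def annotate_from(f):
    # Bind everything except the wrapper: the decorated function is
    # passed as update_wrapper's first positional argument.
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())
```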
which assigns the wrapper function's `__annotations__` attribute from `f.__annotations__` (keep in mind that it is not a copy).
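A small self-contained check of the "not a copy" remark, as a sketch: `update_wrapper` simply does `setattr(wrapper, '__annotations__', getattr(wrapped, '__annotations__'))`, so both functions end up sharing one dict object. The `Circle.__init__` with a `radius` argument is an assumption added only so the class can be instantiated.

```python
import math
from functools import partial, update_wrapper


def annotate_from(f):
    # Same helper as above: copy only __annotations__ from f.
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())


class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Circle(Shape):
    def __init__(self, radius):  # assumed constructor, for the demo only
        self.radius = radius

    @annotate_from(Shape.area)
    def area(self):
        return math.pi * self.radius ** 2


# The override holds the very same dict object as the base method ...
print(Circle.area.__annotations__ is Shape.area.__annotations__)  # True

# ... so mutating it through the subclass also changes Shape.area.
Circle.area.__annotations__['self'] = 'Circle'
print(Shape.area.__annotations__['self'])  # Circle
```

If that sharing is a problem, copying the dict after decorating (e.g. `area.__annotations__ = dict(area.__annotations__)`) gives each override its own annotations.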
According to the documentation, the default value of `update_wrapper`'s `assigned` argument already includes `__annotations__`, but I can see why you wouldn't want all the other attributes (such as `__qualname__` and `__doc__`) assigned from `wrapped`.
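To illustrate why restricting `assigned` matters, here is a sketch using plain `functools.wraps` (i.e. the default `WRAPPER_ASSIGNMENTS`). The `Square` class and the docstring on `Shape.area` are made up for the demo; the point is that the override picks up the base method's `__qualname__` and `__doc__` along with its annotations.

```python
import functools


class Shape:
    def area(self) -> float:
        """Return the area of the shape."""
        raise NotImplementedError


class Square(Shape):
    def __init__(self, side):  # hypothetical example class
        self.side = side

    # Default wraps(): copies every attribute in WRAPPER_ASSIGNMENTS,
    # which includes __annotations__ but also __qualname__ and __doc__.
    @functools.wraps(Shape.area)
    def area(self):
        return self.side ** 2


print('__annotations__' in functools.WRAPPER_ASSIGNMENTS)  # True
print(Square.area.__annotations__)  # {'return': <class 'float'>}
print(Square.area.__qualname__)     # Shape.area  (not Square.area)
print(Square.area.__doc__)          # Return the area of the shape.
```

Passing `assigned=('__annotations__',)`, as `annotate_from` does, keeps the override's own name and docstring.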
With this you can then define your `Circle` and `Rectangle` as
```python
import math

class Circle(Shape):
    @annotate_from(Shape.area)
    def area(self):
        return math.pi * self.radius ** 2

class Rectangle(Shape):
    @annotate_from(Shape.area)
    def area(self):
        return self.height * self.width
```
and the result:

```
In [82]: Circle.area.__annotations__
Out[82]: {'return': builtins.float}

In [86]: Rectangle.area.__annotations__
Out[86]: {'return': builtins.float}
```
As a side effect, your methods will have a `__wrapped__` attribute, which points to `Shape.area` in this case.
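One practical consequence of that `__wrapped__` attribute, shown as a sketch: `inspect.signature()` follows `__wrapped__` by default, so introspection tools report the base method's signature. This is exactly where the "additional arguments" edge case bites. The `scale` parameter below is hypothetical, added only to make the mismatch visible.

```python
import inspect
from functools import partial, update_wrapper


def annotate_from(f):
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())


class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Rectangle(Shape):
    def __init__(self, height, width):
        self.height = height
        self.width = width

    # Hypothetical extra parameter, to show the edge case.
    @annotate_from(Shape.area)
    def area(self, scale=1.0):
        return self.height * self.width * scale


print(Rectangle.area.__wrapped__ is Shape.area)  # True

# signature() unwraps to Shape.area, hiding the extra parameter ...
print(inspect.signature(Rectangle.area))  # (self) -> float

# ... unless told not to follow __wrapped__.
print(inspect.signature(Rectangle.area, follow_wrapped=False))
# (self, scale=1.0) -> float
```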
A less standard way (if you can call the above use of `update_wrapper` standard) to handle overridden methods is a class decorator:
```python
from inspect import getmembers, isfunction

def override(f):
    """Mark method overrides."""
    f.__override__ = True
    return f

def _is_method_override(m):
    return isfunction(m) and getattr(m, '__override__', False)

def annotate_overrides(cls):
    """Copy annotations of overridden methods."""
    bases = cls.mro()[1:]
    for name, method in getmembers(cls, _is_method_override):
        # Find the first class in the MRO that defines this name.
        for base in bases:
            if hasattr(base, name):
                break
        else:
            raise RuntimeError('method {!r} not found in bases of {!r}'
                               .format(name, cls))
        base_method = getattr(base, name)
        method.__annotations__ = base_method.__annotations__.copy()
    return cls
```
and then:
```python
@annotate_overrides
class Rectangle(Shape):
    @override
    def area(self):
        return self.height * self.width
```
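Here is a self-contained run of the class-decorator approach, as a sketch. Unlike `annotate_from`, it stores a `.copy()` of the base annotations, so the two dicts are independent. It also catches a misspelled override when the class is created. The `Circle` constructor and the misspelled `aera` method are made up for the demo.

```python
import math
from inspect import getmembers, isfunction


def override(f):
    """Mark method overrides."""
    f.__override__ = True
    return f


def _is_method_override(m):
    return isfunction(m) and getattr(m, '__override__', False)


def annotate_overrides(cls):
    """Copy annotations of overridden methods."""
    bases = cls.mro()[1:]
    for name, method in getmembers(cls, _is_method_override):
        for base in bases:
            if hasattr(base, name):
                break
        else:
            raise RuntimeError('method {!r} not found in bases of {!r}'
                               .format(name, cls))
        method.__annotations__ = getattr(base, name).__annotations__.copy()
    return cls


class Shape:
    def area(self) -> float:
        raise NotImplementedError


@annotate_overrides
class Circle(Shape):
    def __init__(self, radius):
        self.radius = radius

    @override
    def area(self):
        return math.pi * self.radius ** 2


print(Circle.area.__annotations__)  # {'return': <class 'float'>}

# A typo in an overriding method's name fails at class creation time.
error = None
try:
    @annotate_overrides
    class Square(Shape):
        @override
        def aera(self):  # deliberately misspelled "area"
            return 1.0
except RuntimeError as exc:
    error = exc
print(error)
```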
Again, this does not handle overriding methods that take additional arguments: the override's own `__annotations__` is replaced wholesale by a copy of the base method's.
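One possible way around that limitation, sketched under my own assumptions rather than as a complete solution, is to merge instead of replace. The idea is to start from the base method's annotations and let anything annotated on the override win. `annotate_overrides_merged` and the `scale`/`round_to` method are hypothetical names for illustration.

```python
from inspect import getmembers, isfunction


def override(f):
    f.__override__ = True
    return f


def annotate_overrides_merged(cls):
    """Hypothetical variant: inherit the base method's annotations, but
    keep (and prefer) any annotations written on the override itself."""
    bases = cls.mro()[1:]
    overrides = getmembers(
        cls, lambda m: isfunction(m) and getattr(m, '__override__', False))
    for name, method in overrides:
        base = next((b for b in bases if hasattr(b, name)), None)
        if base is None:
            raise RuntimeError('method {!r} not found in bases of {!r}'
                               .format(name, cls))
        merged = dict(getattr(base, name).__annotations__)
        merged.update(method.__annotations__)  # override's own hints win
        method.__annotations__ = merged
    return cls


class Shape:
    def scale(self, factor: float) -> 'Shape':
        raise NotImplementedError


@annotate_overrides_merged
class Rectangle(Shape):
    def __init__(self, height, width):
        self.height = height
        self.width = width

    @override
    def scale(self, factor, *, round_to: int = 2):  # extra keyword argument
        return Rectangle(round(self.height * factor, round_to),
                         round(self.width * factor, round_to))


print(Rectangle.scale.__annotations__)
```

Keys are matched by name, so if an override renames or drops a parameter, the base annotation for the old name still lingers in `__annotations__`.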