user magic methods
- user magic methods
- magic_methods
- reflectable_magic_methods
- Pow
- TrueDiv
- FloorDiv
- bool_magic_methods
- wrap_node
- to_node
- user magic methods安裝流程
- 主程式
- _make_user_magic
- \_make\_user\_magic - method_attr
- \_make\_user\_magic - unary\_magic\_impl
- \_make\_user\_magic - binary\_magic\_impl
- \_make\_user\_magic - rbinary\_magic\_impl
- 安裝user magic method
- 調用流程
user magic methods
PyTorch模型支援動態形狀的輸入。在PyTorch的動態形狀系統中,除了之前看過的torch.SymNode之外,還會用到torch.SymInt,torch.SymFloat和torch.SymBool這三個類別。其中torch.SymInt和torch.SymFloat用於表示計算形狀期間產生的symbolic sizes;另外計算形狀期間有可能會需要做邏輯判斷,這時便會用到torch.SymBool來表示symbolic的邏輯值。
但如果我們去查看torch.SymBool,torch.SymInt,torch.SymFloat等類別的定義,卻會發現很多未實作的方法,如__eq__,__lt__等。
其實這些方法不是沒有實作,而是稍後會由torch.fx.experimental.symbolic_shapes模組安裝到torch.SymBool,torch.SymInt,torch.SymFloat上,這些方法被稱為user magic methods。
magic_methods
在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?處我們已經看過magic methods的定義,以及它們是如何被安裝到torch.SymNode上的。
如果回去查看magic_methods的定義,可以知道magic methods包含了unary magic methods和binary magic methods兩個子集合。
在binary magic methods子集合中,有所謂的reflectable_magic_methods,來看看它的定義。
reflectable_magic_methods
reflectable_magic_methods的定義位於torch.fx.experimental.symbolic_shapes.py。它是一個 將方法的名稱對應到lambda函數 的字典,其中key代表方法的名字,value則為該方法的實現:
# Methods that have a `__foo__` as well as `__rfoo__`reflectable_magic_methods={'add':lambdaa,b:a+b,'sub':lambdaa,b:a-b,'mul':lambdaa,b:a*b,'mod':lambdaa,b:a%b,'pow':lambdaa,b:Pow(a,b),'and':lambdaa,b:a&b,'or':lambdaa,b:a|b,'truediv':lambdaa,b:TrueDiv(a,b),'floordiv':lambdaa,b:FloorDiv(a,b),}注意以上方法的入參和回傳值皆為sympy.Expr。
待會在user magic methods安裝流程章節會看到,如果一個binary method本來的名稱是foo,則它會被以__foo__的名稱安裝到SymInt,SymFloat或SymBool上。
例如sub方法會被以__sub__的名稱安裝到SymInt和SymFloat上,之後使用者便可以透過SymInt.__sub__(other)或SymFloat.__sub__(other)來調用這個方法。
如果一個binary method屬於reflectable_magic_methods,那麼除了SymInt.__sub__和SymFloat.__sub__之外,還會多安裝一個__rsub__方法。
那麼__sub__和__rsub__有何不同之處呢?SymInt.__sub__(other)是由自己減去對方,即由self._expr減去other._expr;SymInt.__rsub__(other)則反過來,是由對方減去自己,即由other._expr減去self._expr。
reflectable_magic_methods中大部份方法在做什麼都一目瞭然,只有pow,truediv,floordir三個方法用到了PyTorch中自定義的類別Pow,TrueDiv和FloorDiv,讓我們來看看它們的定義。
Pow
reflectable_magic_methods中的pow方法對應到PyTorch中自定義的Pow類別,其定義如下:
# Overloaded to be compatible with regular Python.# https://github.com/pytorch/pytorch/issues/90900classPow(sympy.Function):@classmethoddefeval(cls,base,exp):ifexp.is_zero:returnsympy.Integer(1)elifbase.is_zeroandexp<0:raiseZeroDivisionError(f"{base}cannot be raised to a negative power")else:returnbase**exp可以看到Pow類別繼承了sympy.Function,並且定義了class methodeval方法,這實際上是在按照sympy的規則來撰寫自定義函數,詳見Creating a Custom Function。
Pow.eval函數的入參cls是Pow類別本身,底數base和指數exp則皆為sympy.Expr。
Pow.eval函數用於指數運算,分以下幾種情況:
- 當指數
exp是0時:直接回傳1 - 當底數
base是0且指數exp為負時:不合法,raiseZeroDivisionError錯誤。數學細節詳見Are exponents with base 0 even defined? - 在正常情況下則會回傳
base的exp次方
可以看出,Pow.eval函數對的核心是sympy的**運算子,PyTorch中為了處理指數為0的特殊情況和底數為0且指數為負的錯誤,才對sympy的**運算子進行了包裝。
TrueDiv
reflectable_magic_methods中的true_div方法對應到TrueDiv類別,其定義如下:
# Overloaded to be compatible with regular Python.# https://github.com/pytorch/pytorch/issues/90900classTrueDiv(sympy.Function):@classmethoddefeval(cls,base,divisor):ifdivisor.is_zero:raiseZeroDivisionError("division by zero")else:returnbase/divisor此處TrueDiv繼承自sympy.Function,並定義了eval方法,可知TrueDiv也是按照sympy規則來撰寫的自定義函數。
TrueDiv.eval函數的入參cls是TrueDiv類別本身,分子base和分母divisor則皆為sympy.Expr。
TrueDiv.eval函數用於除法運算,分以下幾種情況:
- 在分母
divisor為0的情況下,會raiseZeroDivisionError錯誤 - 否則其行為跟
/運算子一樣
這裡也可以看出,TrueDiv.eval函數的核心是sympy的/運算子,PyTorch中為了處理分母為0的錯誤,才對sympy的/運算子進行了包裝。
FloorDiv
reflectable_magic_methods中的floor_div方法對應到FloorDiv類別,其定義如下:
classFloorDiv(sympy.Function):""" We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) """nargs=(2,)precedence=50# precedence of mul # noqa: F811# Default return type for SymPy assumptions.# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlersis_real=True@propertydefbase(self):returnself.args[0]@propertydefdivisor(self):returnself.args[1]def_sympystr(self,printer):base=printer.parenthesize(self.base,self.precedence)divisor=printer.parenthesize(self.divisor,self.precedence)returnf"{base}//{divisor}"# SymPy assumptions based on argument types.def_eval_is_real(self):returnfuzzy_or([self.base.is_real,self.divisor.is_real])def_eval_is_integer(self):returnfuzzy_and([self.base.is_integer,self.divisor.is_integer])# Automatic evaluation.# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval@classmethoddefeval(cls,base,divisor):defcheck_supported_type(x):if(x.is_integerisFalseandx.is_realisFalseandx.is_complex)orx.is_Boolean:raiseTypeError(f"unsupported operand type(s) for //: "f"'{type(base).__name__}' and '{type(divisor).__name__}'"f", expected integer or real")check_supported_type(base)check_supported_type(divisor)# We don't provide the same error message as in Python because SymPy# makes it difficult to check the types.ifdivisor.is_zero:raiseZeroDivisionError("division by zero")ifbase.is_zero:returnsympy.S.Zeroifbase.is_integeranddivisor==1:returnbaseifbase.is_realanddivisor==1:returnsympy.floor(base)ifisinstance(base,sympy.Integer)andisinstance(divisor,sympy.Integer):returnbase//divisorifisinstance(base,(sympy.Integer,sympy.Float))andisinstance(divisor,(sympy.Integer,sympy.Float)):returnsympy.floor(base/divisor)ifisinstance(base,FloorDiv):returnFloorDiv(base.args[0],base.args[1]*divisor)ifisinstance(base,sympy.Add):forainbase.args:gcd=sympy.gcd(a,divisor)ifgcd==divisor:returnFloorDiv(base-a,divisor)+a/gcd gcd=sympy.gcd(base,divisor)ifgcd!=1:returnFloorDiv(sympy.simplify(base/gcd),sympy.simplify(divisor/gcd))FloorDiv.eval函數的入參cls是FloorDiv類別本身,分子base和分母divisor則皆為sympy.Expr。
FloorDiv.eval函數會做除法後取floor,分以下幾種情況:
- 在分母
divisor為0的情況下,會raiseZeroDivisionError錯誤 - 在分子
base為0的情況下,會回傳0 - 在分子
base是整數且分母divisor是1的情況下,不需做除法也不需取floor,會直接回傳分子base - 在分子
base是實數且分母divisor是1的情況下,不需做除法,只需要取floor,回傳sympy.floor(base) - 在分子
base是整數且分母divisor是整數的情況下,會直接做整數除法,回傳base // divisor - 在分子
base是sympy.Integer或sympy.Float且分母divisor亦是sympy.Integer或sympy.Float的情況下,會真的做除法,然後取floor,回傳sympy.floor(base / divisor) - 在分子
base是FloorDiv的情況下,會將分母divisor與base.args[1]相乘,當作新的分母,然後回傳FloorDiv(base.args[0], base.args[1] * divisor) - 在分子
base是sympy.Add的情況下,會遍歷base.args(加數、b)中的每個元素,嘗試找出與分母divisor的最大公因數gcd與divisor相等的元素a。如果找到了,則先單獨計算該元素與divisor的商,即a / gcd,因為a是gcd的倍數,所以a / gcd是一個整數。然後再加上其它元素之和與divisor的商的floor值,即FloorDiv(base - a, divisor),最後回傳 - 最後,如果分子
base與分母divisor的最大公因數gcd不等於1,則會先將分子分母各自除以gcd、各自簡化後再做FloorDiv運算。回傳值為:FloorDiv(sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)) - 如果最後的if條件不成立不會return值回去?
bool_magic_methods
在magic_methods的子集合中,以下方法屬於bool magic methods,會被安裝到SymBool上:
bool_magic_methods={"and","or","sym_not"}magic methods中不屬於bool magic methods者則將會同時被安裝到SymInt和SymFloat兩個類別上。
wrap_node
在等一下會看到的_make_user_magic函數中會大量用到wrap_node函數,其定義位於torch/fx/experimental/symbolic_shapes.py:
defwrap_node(x):# TODO: let C++ also take advantage of thisifisinstance(x,SymNode)andx.constantisnotNone:returnx.constantifx.is_int():returnSymInt(x)elifx.is_float():returnSymFloat(x)elifx.is_bool():returnSymBool(x)else:raiseAssertionError(f"unrecognized return type{x}")wrap_node接受的參數x為一SymNode物件,而SymNode有個constant成員變數:
self.constant:Optional[Union[int,float,bool]]=constantwrap_node函數會檢查SymNode x的constant成員變數,分以下幾種情況:
- 在
SymNode x的constant成員變數非空的情況下,wrap_node會取出constant回傳,而constant的型別是int,float或bool其中之一。 - 如果
constant成員變數為空,則wrap_node會依據型別檢查函數回傳的結果,決定將SymNode包裝成SymInt,SymFloat或者是SymBool,最後將包裝後的物件回傳。 - 如果
SymNode x不滿足上述情況,則會將x的型別視為不合法,會raiseAssertionError錯誤。
to_node
在等一下會看到的binary_magic_impl函數中會用到to_node函數,其定義位於torch/fx/experimental/symbolic_shapes.py:
defto_node(self,num):ifisinstance(num,SymTypes):returnnum.nodeeliftype(num)isbool:returnself.wrap_bool(num)eliftype(num)isint:returnself.wrap_int(num)eliftype(num)isfloat:returnself.wrap_float(num)else:# NotImplemented is important so that Python tries the# other magic methodreturnNotImplementedto_node函數接受以下參數:
self:SymNode物件num:可能是SymTypes(包含SymInt,SymFloat和SymBool)或Python中的bool, float或int
首先檢查num是否為SymTypes:
SymTypes=(SymInt,SymFloat,SymBool)如果num是SymTypes,則to_node函數會取出其node成員變數(型別為SymNode)回傳;如果是Python中的bool, float或int,則會調用對應的wrap函數,將其包裝成SymNode後回傳。
在wrap_bool函數中會檢查入參是否為Python中的bool,如果是,便將sympy.true或sympy.false當作expr參數傳入SymNode的建構子,重新建構一個SymNode後回傳。
wrap_int、wrap_float也類似,分別會將Python中的int和float包裝成SymNode後回傳。
另外要注意的一點是,在wrap_bool,wrap_int和wrap_float等函數回傳的SymNode物件中會將constant成員變數設定為num。
user magic methods安裝流程
為SymInt,SymFloat或SymBool安裝user magic methods的程式碼位於torch/fx/experimental/symbolic_shapes.py。
主程式
在安裝user magic methods的主程式中,會遍歷magic_methods,一一對 函數名稱method和lambda函數func的pair 呼叫_make_user_magic函數:
formethod,funcinmagic_methods.items():ifmethodinbool_magic_methods:_make_user_magic(method,SymBool)else:_make_user_magic(method,SymInt)_make_user_magic(method,SymFloat)這段程式碼會為SymBool安裝magic_methods中屬於bool_magic_methods的方法,包括__and__,__or__,__sym__not__。
為SymInt和SymFloat安裝magic_methods中不屬於bool_magic_methods的方法,包括__add__,__sub__,__mul__,__mod__,__pow__,__and__,__or__,__truediv__,__floordiv__,__sym__not__,__eq__,__ne__,__gt__,__lt__,__le__,__ge__,__floor__,__sym__float__,__ceil__,__neg__,__sym__min__,__sym__max__,__sym__sqrt。
最後還會為SymInt和SymFloat安裝reflectable magic methods,包括__radd__,__rsub__,__rmul__,__rmod__,__rpow__,__rand__,__ror__,__rtruediv__,__rfloordiv__。
_make_user_magic
_make_user_magic函數的作用是為SymInt,SymFloat或SymBool安裝名為__method__的方法,等一下會看到,其實__method__方法就是我們在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?見過的SymNode._method_attr方法的包裝。
_make_user_magic接受method和user_type兩個參數:
method代表方法的名稱,方法會以__method__的名稱被安裝user_type:要為哪個類別安裝方法,可以是SymInt,SymFloat或SymBool其中之一
def_make_user_magic(method,user_type):# User magic takes care of wrapping the other operand into a node,# so that our internal logic can assume everything is nodesifmethodinmagic_methods_on_operator_with_trailing_underscore:method_attr=f"{method}_"else:method_attr=methoddefunary_magic_impl(self):returnwrap_node(getattr(self.node,method_attr)())defbinary_magic_impl(self,other):other_node=to_node(self.node,other)ifother_nodeisNotImplemented:returnNotImplementedreturnwrap_node(getattr(self.node,method_attr)(other_node))defrbinary_magic_impl(self,other):other_node=to_node(self.node,other)ifother_nodeisNotImplemented:returnNotImplementedreturnwrap_node(getattr(other_node,method_attr)(self.node))ifmethodinunary_magic_methods:setattr(user_type,f"__{method}__",unary_magic_impl)else:setattr(user_type,f"__{method}__",binary_magic_impl)ifmethodinreflectable_magic_methods:setattr(user_type,f"__r{method}__",rbinary_magic_impl)這個函數比較長,我們可以將它拆成五個部份來看:
- 一開始定義了
method_attr變數 - 中間定義了
unary_magic_impl子函數 - 接著定義了
binary_magic_impl子函數 - 接著定義了
rbinary_magic_impl子函數 - 最後則是實際把
unary_magic_impl,binary_magic_impl或rbinary_magic_impl安裝在user_type上
_make_user_magic - method_attr
如果入參method在 magic_methods_on_operator_with_trailing_underscore (包括and和or)中,則會將method_attr變數設為method + "_",否則將method_attr設為method。
ifmethodinmagic_methods_on_operator_with_trailing_underscore:method_attr=f"{method}_"else:method_attr=method待會會看到,unary_magic_impl,binary_magic_impl和rbinary_magic_impl都是對SymNode方法的包裝,而它們調用的SymNode方法名稱即為method_attr。
_make_user_magic - unary_magic_impl
因為unary_magic_impl函數即將被安裝在的SymInt,SymFloat或SymBool身上,可知unary_magic_impl的參數self就是SymInt,SymFloat或SymBool其中之一:
defunary_magic_impl(self):returnwrap_node(getattr(self.node,method_attr)())SymBool,SymInt或SymFloat都有一個SymNode型別的成員變數node,此處的self.node將取出其node成員變數。
getattr(self.node, method_attr)會獲取SymNode的method_attr方法。
參考四則運算函數,在SymNode.method_attr方法中會調用SymNode._method_attr方法。
注意到SymNode._method_attr方法就是PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」? - magic method安裝流程處由_make_node_magic安裝的unary_magic_impl,它會接受此處傳入的self.node(型別為SymNode)作為參數。
在_make_node_magic裡的unary_magic_impl(它與此處_make_user_magic裡的unary_magic_impl是兩個不同的函數)中,會透過SymNode.expr方法把SymNode的_expr成員變數取出,把它當作參數傳入magic methods字典裡的lambda函數,lambda函數用輸入的sympy.Expr做運算後會回傳另一個sympy.Expr物件。_make_node_magic裡的unary_magic_impl會把回傳的sympy.Expr做包裝,得到一個SymNode之後回傳。
在此處_make_user_magic裡的unary_magic_impl函數中,會將這個SymNode用wrap_node函數包起來。
wrap_node函數在入參SymNode的constant成員變數非空的情況下會回傳底層的Pythonint,float或bool;如果constant成員變數為空,則會將入參SymNode包裝成SymInt,SymFloat或SymBool後回傳。
因為在_make_node_magic - unary_magic_impl處創建的SymNode並未設定constant成員變數,所以此處wrap_node函數回傳的是SymInt,SymFloat或SymBool其中之一。
總結一下,unary_magic_impl接受SymInt,SymFloat或SymBool為參數,調用底層的SymNodemagic method對它做運算,最後同樣回傳SymInt,SymFloat或者是SymBool。
_make_user_magic - binary_magic_impl
比起unary_magic_impl,binary_magic_impl多了一個參數other:
defbinary_magic_impl(self,other):在無法保證other之型別的情況下,需要做以下前處理:
other_node=to_node(self.node,other)to_node函數的作用依照other的型別有所不同:
- 在
other屬於SymTypes = (SymInt, SymFloat, SymBool)的情況下,to_node會直接取出它們的node成員變數(也就是SymNode)回傳 - 在
other屬於int, float, bool的情況下,to_node會用wrap_int, wrap_float或wrap_bool函數對它們做包裝,得到一個SymNode後回傳
接著透過()運算子調用self.node.method_attr方法,傳入other_node,得到一個SymNode物件。最後用wrap_node函數將SymNode包裝成SymInt,SymFloat或者是SymBool後回傳:
returnwrap_node(getattr(self.node,method_attr)(other_node))_make_user_magic - rbinary_magic_impl
在method屬於reflectable_magic_methods的情況下,會額外安裝所謂的rbinary函數。
rbinary_magic_impl函數與binary_magic_impl大體相同,不同之處在於以下這行程式碼:
getattr(other_node,method_attr)(self.node)binary_magic_impl是在self.node上調用method_attr,並把other_node當作參數傳入;rbinary_magic_impl則反過來,在other_node上調用method_attr,把self.node當作參數傳入。
以pow函數為例,binary版本是以self.expr為底數,other.expr為指數;rbinary版本則是相反,以other.expr為底數,self.expr為指數。
安裝user magic method
定義完必要的子函數後,在_make_user_magic的最後會檢查入參method是否屬於unary_magic_methods、將它們分為unary和binary兩類。
對於unary method,就用unary_magic_impl包裝,然後安裝到user_type上;對於binary method,則用binary_magic_impl做包裝,同樣安裝到user_type上。
在binary method底下,還有個子集合reflectable_magic_methods,如果method屬於reflectable_magic_methods,則會額外用rbinary_magic_impl包裝,然後安裝到user_type上。
ifmethodinunary_magic_methods:setattr(user_type,f"__{method}__",unary_magic_impl)else:setattr(user_type,f"__{method}__",binary_magic_impl)ifmethodinreflectable_magic_methods:setattr(user_type,f"__r{method}__",rbinary_magic_impl)之後我們就可以透過對SymInt、SymFloat或SymBool物件呼叫__method_attr__或__rmethod_attr__方法來調用。讓我們來驗證一下,在Python命令行裡查看SymBool.__and__的方法,會出現以下輸出:
>>>torch.fx.experimental.symbolic_shapes.SymBool.__and__<function _make_user_magic.<locals>.binary_magic_impl at0x7fa81db14280>在Python命令行裡查看SymInt.__add__和SymInt.__radd__的方法,會出現以下輸出:
importtorch>>>torch.fx.experimental.symbolic_shapes.SymInt.__add__<function _make_user_magic.<locals>.binary_magic_impl at0x7fa81db13700>>>>torch.fx.experimental.symbolic_shapes.SymInt.__radd__<function _make_user_magic.<locals>.rbinary_magic_impl at0x7fa81db13790>從輸出可以看出來,SymBool.__and__確實跟_make_user_magic和binary_magic_impl有關,SymInt.__add__和SymInt.__radd__也是如此,代表它們確實是由torch/fx/experimental/symbolic_shapes.py安裝的user magic methods。
調用流程
使用者可以透過SymInt / 1.的語法來做除法運算,此處用到了SymInt的/運算子,但從torch.SymInt中卻找不到/運算子的定義。
參考Python __truediv__() Magic Method:
to evaluate the expression x / y, Python attempts to call x.__truediv__(y)在Python中使用/運算子時,底層會呼叫__truediv__。
而我們知道,此處調用的SymInt.__truediv__方法是在SymInt被定義後才被安裝上去的user magic method。
注意到SymInt的__truediv__方法接受的運算元是SymInt本身和一個數字1.,兩個不同型別的運算元該如何做運算呢?
參考binary_magic_impl,在binary_magic_impl函數中有個前處理,會調用to_node函數將1.包裝成SymNode,如此一來__truediv__的兩個運算元便都是SymNode型別了。
如user magic methods安裝流程章節所述,SymInt.__truediv__會呼叫SymNode.truediv,而根據四則運算函數,SymNode.truediv會進一步呼叫SymNode._truediv。
參考_make_node_magic - binary_magic_impl,在SymNode._truediv中,會先取出兩個SymNode的expr成員變數(型別為sympy.Expr),對它們做TrueDiv操作,得到新的sympy.Expr。在SymNode._truediv的最後,會將運算得到的sympy.Expr當作SymNode建構子的expr參數,創建一個SymNode後回傳。注意此處創建出來的SymNode的pytype為float。
SymInt.__truediv__得到SymNode._truediv和SymNode.truediv回傳的SymNode後,會再調用wrap_node函數。
在wrap_node函數中,如果SymNode的pytype為float,則會將SymNode包裝成SymFloat後回傳。
所以SymInt / 1.會得到一個SymFloat。
以下三行分別代表這整個過程中的函數調用鏈,各函數的參數和回傳值,整理如下:
SymInt.__truediv__→SymNode.truediv→SymNode._truediv→ lambda函數 →TrueDiv
SymInt和float→ 兩個SymNode→ 兩個SymNode→ 兩個sympy.Expr→ 兩個sympy.Expr
↓
SymFloat←SymNode←SymNode←sympy.Expr←sympy.Expr
注:其中SymInt.__truediv__即透過_make_user_magic被安裝到SymInt上的binary_magic_impl,SymNode._truediv即透過_make_node_magic被安裝到SymNode上的binary_magic_impl。