芜湖市网站建设_网站建设公司_论坛网站_seo优化
2025/12/17 0:12:27 网站建设 项目流程

user magic methods

  • user magic methods
    • magic_methods
    • reflectable_magic_methods
      • Pow
      • TrueDiv
      • FloorDiv
    • bool_magic_methods
    • wrap_node
    • to_node
  • user magic methods安裝流程
    • 主程式
    • _make_user_magic
      • \_make\_user\_magic - method_attr
      • \_make\_user\_magic - unary\_magic\_impl
      • \_make\_user\_magic - binary\_magic\_impl
      • \_make\_user\_magic - rbinary\_magic\_impl
      • 安裝user magic method
  • 調用流程

user magic methods

PyTorch模型支援動態形狀的輸入。在PyTorch的動態形狀系統中,除了之前看過的torch.SymNode之外,還會用到torch.SymInt,torch.SymFloattorch.SymBool這三個類別。其中torch.SymInttorch.SymFloat用於表示計算形狀期間產生的symbolic sizes;另外計算形狀期間有可能會需要做邏輯判斷,這時便會用到torch.SymBool來表示symbolic的邏輯值。

但如果我們去查看torch.SymBool,torch.SymInt,torch.SymFloat等類別的定義,卻會發現很多未實作的方法,如__eq__,__lt__等。

其實這些方法不是沒有實作,而是稍後會由torch.fx.experimental.symbolic_shapes模組安裝到torch.SymBool,torch.SymInt,torch.SymFloat上,這些方法被稱為user magic methods。

magic_methods

在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?處我們已經看過magic methods的定義,以及它們是如何被安裝到torch.SymNode上的。

如果回去查看magic_methods的定義,可以知道magic methods包含了unary magic methods和binary magic methods兩個子集合。

在binary magic methods子集合中,有所謂的reflectable_magic_methods,來看看它的定義。

reflectable_magic_methods

reflectable_magic_methods的定義位於torch.fx.experimental.symbolic_shapes.py。它是一個 將方法的名稱對應到lambda函數 的字典,其中key代表方法的名字,value則為該方法的實現:

# Methods that have a `__foo__` as well as `__rfoo__`reflectable_magic_methods={'add':lambdaa,b:a+b,'sub':lambdaa,b:a-b,'mul':lambdaa,b:a*b,'mod':lambdaa,b:a%b,'pow':lambdaa,b:Pow(a,b),'and':lambdaa,b:a&b,'or':lambdaa,b:a|b,'truediv':lambdaa,b:TrueDiv(a,b),'floordiv':lambdaa,b:FloorDiv(a,b),}

注意以上方法的入參和回傳值皆為sympy.Expr

待會在user magic methods安裝流程章節會看到,如果一個binary method本來的名稱是foo,則它會被以__foo__的名稱安裝到SymInt,SymFloatSymBool上。

例如sub方法會被以__sub__的名稱安裝到SymIntSymFloat上,之後使用者便可以透過SymInt.__sub__(other)SymFloat.__sub__(other)來調用這個方法。

如果一個binary method屬於reflectable_magic_methods,那麼除了SymInt.__sub__SymFloat.__sub__之外,還會多安裝一個__rsub__方法。

那麼__sub____rsub__有何不同之處呢?
SymInt.__sub__(other)是由自己減去對方,即由self._expr減去other._exprSymInt.__rsub__(other)則反過來,是由對方減去自己,即由other._expr減去self._expr

reflectable_magic_methods中大部份方法在做什麼都一目瞭然,只有pow,truediv,floordir三個方法用到了PyTorch中自定義的類別Pow,TrueDivFloorDiv,讓我們來看看它們的定義。

Pow

reflectable_magic_methods中的pow方法對應到PyTorch中自定義的Pow類別,其定義如下:

# Overloaded to be compatible with regular Python.# https://github.com/pytorch/pytorch/issues/90900classPow(sympy.Function):@classmethoddefeval(cls,base,exp):ifexp.is_zero:returnsympy.Integer(1)elifbase.is_zeroandexp<0:raiseZeroDivisionError(f"{base}cannot be raised to a negative power")else:returnbase**exp

可以看到Pow類別繼承了sympy.Function,並且定義了class methodeval方法,這實際上是在按照sympy的規則來撰寫自定義函數,詳見Creating a Custom Function。

Pow.eval函數的入參clsPow類別本身,底數base和指數exp則皆為sympy.Expr

Pow.eval函數用於指數運算,分以下幾種情況:

  • 當指數exp是0時:直接回傳1
  • 當底數base是0且指數exp為負時:不合法,raiseZeroDivisionError錯誤。數學細節詳見Are exponents with base 0 even defined?
  • 在正常情況下則會回傳baseexp次方

可以看出,Pow.eval函數對的核心是sympy**運算子,PyTorch中為了處理指數為0的特殊情況和底數為0且指數為負的錯誤,才對sympy**運算子進行了包裝。

TrueDiv

reflectable_magic_methods中的true_div方法對應到TrueDiv類別,其定義如下:

# Overloaded to be compatible with regular Python.# https://github.com/pytorch/pytorch/issues/90900classTrueDiv(sympy.Function):@classmethoddefeval(cls,base,divisor):ifdivisor.is_zero:raiseZeroDivisionError("division by zero")else:returnbase/divisor

此處TrueDiv繼承自sympy.Function,並定義了eval方法,可知TrueDiv也是按照sympy規則來撰寫的自定義函數。

TrueDiv.eval函數的入參clsTrueDiv類別本身,分子base和分母divisor則皆為sympy.Expr

TrueDiv.eval函數用於除法運算,分以下幾種情況:

  • 在分母divisor為0的情況下,會raiseZeroDivisionError錯誤
  • 否則其行為跟/運算子一樣

這裡也可以看出,TrueDiv.eval函數的核心是sympy/運算子,PyTorch中為了處理分母為0的錯誤,才對sympy/運算子進行了包裝。

FloorDiv

reflectable_magic_methods中的floor_div方法對應到FloorDiv類別,其定義如下:

classFloorDiv(sympy.Function):""" We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) """nargs=(2,)precedence=50# precedence of mul # noqa: F811# Default return type for SymPy assumptions.# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlersis_real=True@propertydefbase(self):returnself.args[0]@propertydefdivisor(self):returnself.args[1]def_sympystr(self,printer):base=printer.parenthesize(self.base,self.precedence)divisor=printer.parenthesize(self.divisor,self.precedence)returnf"{base}//{divisor}"# SymPy assumptions based on argument types.def_eval_is_real(self):returnfuzzy_or([self.base.is_real,self.divisor.is_real])def_eval_is_integer(self):returnfuzzy_and([self.base.is_integer,self.divisor.is_integer])# Automatic evaluation.# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval@classmethoddefeval(cls,base,divisor):defcheck_supported_type(x):if(x.is_integerisFalseandx.is_realisFalseandx.is_complex)orx.is_Boolean:raiseTypeError(f"unsupported operand type(s) for //: "f"'{type(base).__name__}' and '{type(divisor).__name__}'"f", expected integer or real")check_supported_type(base)check_supported_type(divisor)# We don't provide the same error message as in Python because SymPy# makes it difficult to check the types.ifdivisor.is_zero:raiseZeroDivisionError("division by zero")ifbase.is_zero:returnsympy.S.Zeroifbase.is_integeranddivisor==1:returnbaseifbase.is_realanddivisor==1:returnsympy.floor(base)ifisinstance(base,sympy.Integer)andisinstance(divisor,sympy.Integer):returnbase//divisorifisinstance(base,(sympy.Integer,sympy.Float))andisinstance(divisor,(sympy.Integer,sympy.Float)):returnsympy.floor(base/divisor)ifisinstance(base,FloorDiv):returnFloorDiv(base.args[0],base.args[1]*divisor)ifisinstance(base,sympy.Add):forainbase.args:gcd=sympy.gcd(a,divisor)ifgcd==divisor:returnFloorDiv(base-a,divisor)+a/gcd gcd=sympy.gcd(base,divisor)ifgcd!=1:returnFloorDiv(sympy.simplify(base/gcd),sympy.simplify(divisor/gcd))

FloorDiv.eval函數的入參clsFloorDiv類別本身,分子base和分母divisor則皆為sympy.Expr

FloorDiv.eval函數會做除法後取floor,分以下幾種情況:

  • 在分母divisor為0的情況下,會raiseZeroDivisionError錯誤
  • 在分子base為0的情況下,會回傳0
  • 在分子base是整數且分母divisor是1的情況下,不需做除法也不需取floor,會直接回傳分子base
  • 在分子base是實數且分母divisor是1的情況下,不需做除法,只需要取floor,回傳sympy.floor(base)
  • 在分子base是整數且分母divisor是整數的情況下,會直接做整數除法,回傳base // divisor
  • 在分子basesympy.Integersympy.Float且分母divisor亦是sympy.Integersympy.Float的情況下,會真的做除法,然後取floor,回傳sympy.floor(base / divisor)
  • 在分子baseFloorDiv的情況下,會將分母divisorbase.args[1]相乘,當作新的分母,然後回傳FloorDiv(base.args[0], base.args[1] * divisor)
  • 在分子basesympy.Add的情況下,會遍歷base.args(加數、b)中的每個元素,嘗試找出與分母divisor的最大公因數gcddivisor相等的元素a。如果找到了,則先單獨計算該元素與divisor的商,即a / gcd,因為agcd的倍數,所以a / gcd是一個整數。然後再加上其它元素之和與divisor的商的floor值,即FloorDiv(base - a, divisor),最後回傳
  • 最後,如果分子base與分母divisor的最大公因數gcd不等於1,則會先將分子分母各自除以gcd、各自簡化後再做FloorDiv運算。回傳值為:FloorDiv(sympy.simplify(base / gcd), sympy.simplify(divisor / gcd))
  • 如果最後的if條件不成立不會return值回去?

bool_magic_methods

magic_methods的子集合中,以下方法屬於bool magic methods,會被安裝到SymBool上:

bool_magic_methods={"and","or","sym_not"}

magic methods中不屬於bool magic methods者則將會同時被安裝到SymIntSymFloat兩個類別上。

wrap_node

在等一下會看到的_make_user_magic函數中會大量用到wrap_node函數,其定義位於torch/fx/experimental/symbolic_shapes.py

defwrap_node(x):# TODO: let C++ also take advantage of thisifisinstance(x,SymNode)andx.constantisnotNone:returnx.constantifx.is_int():returnSymInt(x)elifx.is_float():returnSymFloat(x)elifx.is_bool():returnSymBool(x)else:raiseAssertionError(f"unrecognized return type{x}")

wrap_node接受的參數x為一SymNode物件,而SymNode有個constant成員變數:

self.constant:Optional[Union[int,float,bool]]=constant

wrap_node函數會檢查SymNode xconstant成員變數,分以下幾種情況:

  • SymNode xconstant成員變數非空的情況下,wrap_node會取出constant回傳,而constant的型別是int,floatbool其中之一。
  • 如果constant成員變數為空,則wrap_node會依據型別檢查函數回傳的結果,決定將SymNode包裝成SymInt,SymFloat或者是SymBool,最後將包裝後的物件回傳。
  • 如果SymNode x不滿足上述情況,則會將x的型別視為不合法,會raiseAssertionError錯誤。

to_node

在等一下會看到的binary_magic_impl函數中會用到to_node函數,其定義位於torch/fx/experimental/symbolic_shapes.py

defto_node(self,num):ifisinstance(num,SymTypes):returnnum.nodeeliftype(num)isbool:returnself.wrap_bool(num)eliftype(num)isint:returnself.wrap_int(num)eliftype(num)isfloat:returnself.wrap_float(num)else:# NotImplemented is important so that Python tries the# other magic methodreturnNotImplemented

to_node函數接受以下參數:

  • selfSymNode物件
  • num:可能是SymTypes(包含SymInt,SymFloatSymBool)或Python中的bool, float或int

首先檢查num是否為SymTypes

SymTypes=(SymInt,SymFloat,SymBool)

如果numSymTypes,則to_node函數會取出其node成員變數(型別為SymNode)回傳;如果是Python中的bool, float或int,則會調用對應的wrap函數,將其包裝成SymNode後回傳。

在wrap_bool函數中會檢查入參是否為Python中的bool,如果是,便將sympy.truesympy.false當作expr參數傳入SymNode的建構子,重新建構一個SymNode後回傳。

wrap_intwrap_float也類似,分別會將Python中的int和float包裝成SymNode後回傳。

另外要注意的一點是,在wrap_bool,wrap_intwrap_float等函數回傳的SymNode物件中會將constant成員變數設定為num

user magic methods安裝流程

SymInt,SymFloatSymBool安裝user magic methods的程式碼位於torch/fx/experimental/symbolic_shapes.py

主程式

在安裝user magic methods的主程式中,會遍歷magic_methods,一一對 函數名稱method和lambda函數func的pair 呼叫_make_user_magic函數:

formethod,funcinmagic_methods.items():ifmethodinbool_magic_methods:_make_user_magic(method,SymBool)else:_make_user_magic(method,SymInt)_make_user_magic(method,SymFloat)

這段程式碼會為SymBool安裝magic_methods中屬於bool_magic_methods的方法,包括__and__,__or__,__sym__not__

SymIntSymFloat安裝magic_methods中不屬於bool_magic_methods的方法,包括__add__,__sub__,__mul__,__mod__,__pow__,__and__,__or__,__truediv__,__floordiv__,__sym__not__,__eq__,__ne__,__gt__,__lt__,__le__,__ge__,__floor__,__sym__float__,__ceil__,__neg__,__sym__min__,__sym__max__,__sym__sqrt

最後還會為SymIntSymFloat安裝reflectable magic methods,包括__radd__,__rsub__,__rmul__,__rmod__,__rpow__,__rand__,__ror__,__rtruediv__,__rfloordiv__

_make_user_magic

_make_user_magic函數的作用是為SymInt,SymFloatSymBool安裝名為__method__的方法,等一下會看到,其實__method__方法就是我們在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?見過的SymNode._method_attr方法的包裝。

_make_user_magic接受methoduser_type兩個參數:

  • method代表方法的名稱,方法會以__method__的名稱被安裝
  • user_type:要為哪個類別安裝方法,可以是SymInt,SymFloatSymBool其中之一
def_make_user_magic(method,user_type):# User magic takes care of wrapping the other operand into a node,# so that our internal logic can assume everything is nodesifmethodinmagic_methods_on_operator_with_trailing_underscore:method_attr=f"{method}_"else:method_attr=methoddefunary_magic_impl(self):returnwrap_node(getattr(self.node,method_attr)())defbinary_magic_impl(self,other):other_node=to_node(self.node,other)ifother_nodeisNotImplemented:returnNotImplementedreturnwrap_node(getattr(self.node,method_attr)(other_node))defrbinary_magic_impl(self,other):other_node=to_node(self.node,other)ifother_nodeisNotImplemented:returnNotImplementedreturnwrap_node(getattr(other_node,method_attr)(self.node))ifmethodinunary_magic_methods:setattr(user_type,f"__{method}__",unary_magic_impl)else:setattr(user_type,f"__{method}__",binary_magic_impl)ifmethodinreflectable_magic_methods:setattr(user_type,f"__r{method}__",rbinary_magic_impl)

這個函數比較長,我們可以將它拆成五個部份來看:

  • 一開始定義了method_attr變數
  • 中間定義了unary_magic_impl子函數
  • 接著定義了binary_magic_impl子函數
  • 接著定義了rbinary_magic_impl子函數
  • 最後則是實際把unary_magic_impl,binary_magic_implrbinary_magic_impl安裝在user_type

_make_user_magic - method_attr

如果入參method在 magic_methods_on_operator_with_trailing_underscore (包括andor)中,則會將method_attr變數設為method + "_",否則將method_attr設為method

ifmethodinmagic_methods_on_operator_with_trailing_underscore:method_attr=f"{method}_"else:method_attr=method

待會會看到,unary_magic_impl,binary_magic_implrbinary_magic_impl都是對SymNode方法的包裝,而它們調用的SymNode方法名稱即為method_attr

_make_user_magic - unary_magic_impl

因為unary_magic_impl函數即將被安裝在的SymInt,SymFloatSymBool身上,可知unary_magic_impl的參數self就是SymInt,SymFloatSymBool其中之一:

defunary_magic_impl(self):returnwrap_node(getattr(self.node,method_attr)())

SymBool,SymIntSymFloat都有一個SymNode型別的成員變數node,此處的self.node將取出其node成員變數。

getattr(self.node, method_attr)會獲取SymNodemethod_attr方法。

參考四則運算函數,在SymNode.method_attr方法中會調用SymNode._method_attr方法。

注意到SymNode._method_attr方法就是PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」? - magic method安裝流程處由_make_node_magic安裝的unary_magic_impl,它會接受此處傳入的self.node(型別為SymNode)作為參數。

_make_node_magic裡的unary_magic_impl(它與此處_make_user_magic裡的unary_magic_impl是兩個不同的函數)中,會透過SymNode.expr方法把SymNode的_expr成員變數取出,把它當作參數傳入magic methods字典裡的lambda函數,lambda函數用輸入的sympy.Expr做運算後會回傳另一個sympy.Expr物件。_make_node_magic裡的unary_magic_impl會把回傳的sympy.Expr做包裝,得到一個SymNode之後回傳。

在此處_make_user_magic裡的unary_magic_impl函數中,會將這個SymNode用wrap_node函數包起來。

wrap_node函數在入參SymNodeconstant成員變數非空的情況下會回傳底層的Pythonint,floatbool;如果constant成員變數為空,則會將入參SymNode包裝成SymInt,SymFloatSymBool後回傳。

因為在_make_node_magic - unary_magic_impl處創建的SymNode並未設定constant成員變數,所以此處wrap_node函數回傳的是SymInt,SymFloatSymBool其中之一。

總結一下,unary_magic_impl接受SymInt,SymFloatSymBool為參數,調用底層的SymNodemagic method對它做運算,最後同樣回傳SymInt,SymFloat或者是SymBool

_make_user_magic - binary_magic_impl

比起unary_magic_implbinary_magic_impl多了一個參數other

defbinary_magic_impl(self,other):

在無法保證other之型別的情況下,需要做以下前處理:

other_node=to_node(self.node,other)

to_node函數的作用依照other的型別有所不同:

  • other屬於SymTypes = (SymInt, SymFloat, SymBool)的情況下,to_node會直接取出它們的node成員變數(也就是SymNode)回傳
  • other屬於int, float, bool的情況下,to_node會用wrap_int, wrap_float或wrap_bool函數對它們做包裝,得到一個SymNode後回傳

接著透過()運算子調用self.node.method_attr方法,傳入other_node,得到一個SymNode物件。最後用wrap_node函數將SymNode包裝成SymInt,SymFloat或者是SymBool後回傳:

returnwrap_node(getattr(self.node,method_attr)(other_node))

_make_user_magic - rbinary_magic_impl

method屬於reflectable_magic_methods的情況下,會額外安裝所謂的rbinary函數。

rbinary_magic_impl函數與binary_magic_impl大體相同,不同之處在於以下這行程式碼:

getattr(other_node,method_attr)(self.node)

binary_magic_impl是在self.node上調用method_attr,並把other_node當作參數傳入;rbinary_magic_impl則反過來,在other_node上調用method_attr,把self.node當作參數傳入。

pow函數為例,binary版本是以self.expr為底數,other.expr為指數;rbinary版本則是相反,以other.expr為底數,self.expr為指數。

安裝user magic method

定義完必要的子函數後,在_make_user_magic的最後會檢查入參method是否屬於unary_magic_methods、將它們分為unary和binary兩類。

對於unary method,就用unary_magic_impl包裝,然後安裝到user_type上;對於binary method,則用binary_magic_impl做包裝,同樣安裝到user_type上。

在binary method底下,還有個子集合reflectable_magic_methods,如果method屬於reflectable_magic_methods,則會額外用rbinary_magic_impl包裝,然後安裝到user_type上。

ifmethodinunary_magic_methods:setattr(user_type,f"__{method}__",unary_magic_impl)else:setattr(user_type,f"__{method}__",binary_magic_impl)ifmethodinreflectable_magic_methods:setattr(user_type,f"__r{method}__",rbinary_magic_impl)

之後我們就可以透過對SymIntSymFloatSymBool物件呼叫__method_attr____rmethod_attr__方法來調用。讓我們來驗證一下,在Python命令行裡查看SymBool.__and__的方法,會出現以下輸出:

>>>torch.fx.experimental.symbolic_shapes.SymBool.__and__<function _make_user_magic.<locals>.binary_magic_impl at0x7fa81db14280>

在Python命令行裡查看SymInt.__add__SymInt.__radd__的方法,會出現以下輸出:

importtorch>>>torch.fx.experimental.symbolic_shapes.SymInt.__add__<function _make_user_magic.<locals>.binary_magic_impl at0x7fa81db13700>>>>torch.fx.experimental.symbolic_shapes.SymInt.__radd__<function _make_user_magic.<locals>.rbinary_magic_impl at0x7fa81db13790>

從輸出可以看出來,SymBool.__and__確實跟_make_user_magicbinary_magic_impl有關,SymInt.__add__SymInt.__radd__也是如此,代表它們確實是由torch/fx/experimental/symbolic_shapes.py安裝的user magic methods。

調用流程

使用者可以透過SymInt / 1.的語法來做除法運算,此處用到了SymInt/運算子,但從torch.SymInt中卻找不到/運算子的定義。

參考Python __truediv__() Magic Method:

to evaluate the expression x / y, Python attempts to call x.__truediv__(y)

在Python中使用/運算子時,底層會呼叫__truediv__
而我們知道,此處調用的SymInt.__truediv__方法是在SymInt被定義後才被安裝上去的user magic method。

注意到SymInt__truediv__方法接受的運算元是SymInt本身和一個數字1.,兩個不同型別的運算元該如何做運算呢?
參考binary_magic_impl,在binary_magic_impl函數中有個前處理,會調用to_node函數將1.包裝成SymNode,如此一來__truediv__的兩個運算元便都是SymNode型別了。

如user magic methods安裝流程章節所述,SymInt.__truediv__會呼叫SymNode.truediv,而根據四則運算函數,SymNode.truediv會進一步呼叫SymNode._truediv

參考_make_node_magic - binary_magic_impl,在SymNode._truediv中,會先取出兩個SymNodeexpr成員變數(型別為sympy.Expr),對它們做TrueDiv操作,得到新的sympy.Expr。在SymNode._truediv的最後,會將運算得到的sympy.Expr當作SymNode建構子的expr參數,創建一個SymNode後回傳。注意此處創建出來的SymNodepytypefloat

SymInt.__truediv__得到SymNode._truedivSymNode.truediv回傳的SymNode後,會再調用wrap_node函數。
wrap_node函數中,如果SymNodepytypefloat,則會將SymNode包裝成SymFloat後回傳。

所以SymInt / 1.會得到一個SymFloat

以下三行分別代表這整個過程中的函數調用鏈,各函數的參數和回傳值,整理如下:

SymInt.__truediv__SymNode.truedivSymNode._truediv→ lambda函數 →TrueDiv

SymIntfloat→ 兩個SymNode→ 兩個SymNode→ 兩個sympy.Expr→ 兩個sympy.Expr

SymFloatSymNodeSymNodesympy.Exprsympy.Expr

注:其中SymInt.__truediv__即透過_make_user_magic被安裝到SymInt上的binary_magic_implSymNode._truediv即透過_make_node_magic被安裝到SymNode上的binary_magic_impl

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询