# latex_transer.py: translate LaTeX math expressions into SymPy expressions.
from sympy import *
# Keep a handle on SymPy's Euler number E: the parser functions below name
# their input-string parameter E, which shadows the SymPy constant inside them.
Ee=E
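'''
Parser layers, from the most tightly bound construct to the loosest:
getUnit			'{} () [] \\'			self-contained unit (bracketed group or \\command)
getIndepend		'abcd'					single standalone symbol
getContinuous	'1234. a^b a_b {x'}'	contiguous token (number, power, subscript)
getElement		'* / '					strongly binding operators (multiply/divide)
getAll			'+ -'					weakly binding operators (the whole sum)
				'='						whole equation
'''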
# Greek-letter LaTeX commands that are turned into real SymPy symbols.
# getUnit and getIndepend also accept each name without its leading
# backslash (e.g. "alpha").
syms={
	# lowercase
	"\\alpha","\\beta","\\gamma","\\delta","\\varepsilon","\\theta",
	"\\kappa","\\lambda","\\mu","\\rho","\\varphi","\\omega",
	# uppercase
	"\\Gamma","\\Delta","\\Lambda","\\Phi","\\Psi","\\Omega",
}
# Characters that may form a numeric literal (digits and the decimal point);
# getContinuous reads the longest leading run of these.
basicNumber="1234567890."
def warning(warn):
	"""Print a notice about an element the parser does not recognise.

	Args:
		warn: the offending element; it is printed after the label.
	"""
	# fixed misspelling in the message ("Unknow" -> "Unknown")
	print("Unknown elem:",warn)
def getUnit(E,dep):
	r"""Parse one self-contained unit at the start of ``E``.

	A unit is a bracketed group (``{...}``, ``(...)``, ``[...]``), a supported
	LaTeX command (``\frac``, ``\sqrt``, ``\sqrt[n]``, ``\sum^``/``\sum_``,
	``\log_``, ``\ln``, trig and inverse-trig functions, ``\pi``,
	``\infty``), ``e``/``E``, ``i``, or a Greek letter from ``syms``.  Any
	other single letter or digit is delegated to ``getIndepend`` so that
	LaTeX shorthand such as ``\frac12`` or ``\sqrt2`` works.

	Args:
		E: remaining LaTeX text (spaces already stripped by ``Trans``).
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(value, rest)``: the parsed SymPy value and the unconsumed text.

	Raises:
		ValueError: for an unknown ``\command``, a stray closing bracket or
			other unparseable character, or a ``^`` with no base.
	"""
	dep=dep+1
	print(' '*dep*2,'>getUnit',E)
	if E[0]=='{' or E[0]=='(' or E[0]=='[':
		res,E=getAll(E[1:],dep)
	elif E[:5]=='\\frac':
		num,E=getUnit(E[5:],dep)
		den,E=getUnit(E,dep)
		# sympify so that int/int stays an exact Rational (1/2, not 0.5)
		res=sympify(num)/den
	elif E[0]=='^':
		# Superscripts are consumed by getContinuous right after their base,
		# so a '^' reaching this point has no base.  (The old branch had a
		# typo and recursed on the unchanged text forever.)
		raise ValueError("'^' without a base: "+E)
	elif E[:6]=='\\sqrt[':
		# E[5:] starts at '[', so the degree is parsed as a bracketed group
		degree,E=getUnit(E[5:],dep)
		radicand,E=getUnit(E,dep)
		# root() gives an exact x**(1/n) instead of a float exponent
		res=root(radicand,degree)
	elif E[:5]=='\\sqrt':
		res,E=getUnit(E[5:],dep)
		res=sqrt(res)
	elif E[:5]=='\\sum^':
		# \sum^{high}_{n=low}elem
		high,E=getIndepend(E[5:],dep)
		n,E=getContinuous(E[2:],dep)	# E[2:] skips '_{'
		low,E=getAll(E[1:],dep)	# E[1:] skips '='; getAll eats the '}'
		elem,E=getElement(E,dep)
		# Sum over the parsed index symbol itself: the old code substituted
		# with a string key, which never matched the real=True symbol.
		res=summation(elem,(n,low,high))
	elif E[:5]=='\\sum_':
		# \sum_{n=low}^{high}elem
		n,E=getContinuous(E[6:],dep)	# E[6:] skips '\sum_{'
		low,E=getAll(E[1:],dep)	# E[1:] skips '='; getAll eats the '}'
		high,E=getIndepend(E[1:],dep)	# E[1:] skips '^'
		elem,E=getElement(E,dep)
		res=summation(elem,(n,low,high))
	elif E[:5]=='\\log_':
		# \log_{base}elem  ->  log(elem)/log(base)
		base,E=getIndepend(E[5:],dep)
		elem,E=getElement(E,dep)
		res=log(elem)/log(base)
	elif E[:3]=='\\ln':
		elem,E=getContinuous(E[3:],dep)
		res=log(elem)
	elif E[:4]=='\\cos':
		elem,E=getContinuous(E[4:],dep)
		res=cos(elem)
	elif E[:4]=='\\sin':
		elem,E=getContinuous(E[4:],dep)
		res=sin(elem)
	elif E[:4]=='\\tan':
		elem,E=getContinuous(E[4:],dep)
		res=tan(elem)
	elif E[:4]=='\\cot':
		elem,E=getContinuous(E[4:],dep)
		res=cot(elem)
	elif E[:5]=='\\acos':
		elem,E=getContinuous(E[5:],dep)
		res=acos(elem)
	elif E[:5]=='\\asin':
		elem,E=getContinuous(E[5:],dep)
		res=asin(elem)
	elif E[:5]=='\\atan':
		elem,E=getContinuous(E[5:],dep)
		res=atan(elem)
	elif E[:5]=='\\acot':
		elem,E=getContinuous(E[5:],dep)
		res=acot(elem)
	elif E[:3]=='\\pi':
		res=pi
		E=E[3:]
	elif E[0]=='e' or E[0]=='E':
		# Euler's number, saved as Ee because the parameter E shadows it.
		# (The old code returned the remaining input text here.)
		res=Ee
		E=E[1:]
	elif E[:6]=='\\infty':
		res=oo
		E=E[6:]
	elif E[0]=='i':
		res=I
		E=E[1:]
	else:
		for s in syms:
			if E.startswith(s):
				# Greek letter written as a command, e.g. \alpha
				res=symbols(s,real=True)
				E=E[len(s):]
				break
			if E.startswith(s[1:]):
				# bare Greek name, e.g. alpha: consume only len(s)-1 chars
				res=symbols(s,real=True)
				E=E[len(s)-1:]
				break
		else:
			if not E[0].isalnum():
				# Unknown \command, stray closing bracket, etc.  The old code
				# returned '' here, which failed later with a confusing error.
				raise ValueError('unrecognized input: '+E)
			# single digit or letter, e.g. the arguments of \frac12 or \sqrt2
			res,E=getIndepend(E,dep)
	print(' '*dep*2,'<getUnit',res,E)
	return res,E
def getIndepend(E,dep):
	r"""Parse a single standalone symbol at the start of ``E``.

	Bracketed groups, ``\commands`` and a leading ``^`` are delegated to
	``getUnit``.  Otherwise exactly one token is consumed:

	* a single digit -> ``int``
	* ``e`` / ``E``  -> Euler's number (``Ee``)
	* ``i``          -> imaginary unit ``I``
	* a bare Greek name from ``syms`` (e.g. ``alpha``) -> real symbol ``\alpha``
	* any other single character -> real symbol of that name

	Args:
		E: remaining LaTeX text.
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(value, rest)``.
	"""
	dep=dep+1
	print(' '*dep*2,'>getIndepend',E)
	head=E[0]
	# '\\' covers \pi, \infty and \alpha-style names: getUnit handles them all
	if head in '{([\\^':
		res,E=getUnit(E,dep)
	elif head in '0123456789':
		res=int(head)
		E=E[1:]
	elif head=='e' or head=='E':
		res=Ee
		E=E[1:]
	elif head=='i':
		res=I
		E=E[1:]
	else:
		for s in syms:
			bare=s[1:]
			if E.startswith(bare):
				res=symbols(s,real=True)
				# consume only the bare name (the old code skipped one char too many)
				E=E[len(bare):]
				break
		else:
			res=symbols(head,real=True)
			E=E[1:]
	print(' '*dep*2,'<getIndepend',res,E)
	return res,E
def getIndependStr(E,dep):
	r"""Return the raw text of the next standalone token and the remainder.

	Used for super-/subscripts, whose text may end up inside a symbol name.

	* A bracketed group is returned including its brackets; nesting is
	  tracked across ``{}``, ``()`` and ``[]`` together.
	* ``\pi``, ``\infty`` and the Greek commands in ``syms`` are returned
	  whole.
	* Anything else yields a single character.

	Args:
		E: remaining LaTeX text (must be non-empty).
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(token_text, rest)``.

	Raises:
		ValueError: if an opening bracket is never closed.
	"""
	dep=dep+1
	print(' '*dep*2,'>getIndependStr',E)
	if E[0] in '{([':
		depth=0
		for pos,ch in enumerate(E):
			# compare the current character (the old code tested E[0], so a
			# '[' group never balanced and fell through to a NameError)
			if ch in '{([':
				depth+=1
			elif ch in '})]':
				depth-=1
			if depth==0:
				return E[:pos+1],E[pos+1:]
		raise ValueError('unbalanced brackets in: '+E)
	if E[0]!='\\':
		# digits, e, E, i and plain letters are single-character tokens
		return E[0],E[1:]
	if E[:3]=='\\pi':
		return E[:3],E[3:]
	if E[:6]=='\\infty':
		return E[:6],E[6:]
	for s in syms:
		if E[:len(s)]==s:
			return s,E[len(s):]
	return E[0],E[1:]

def getContinuous(E,dep):
	r"""Parse one contiguous token and apply any super-/subscript after it.

	The base is either a run of characters from ``basicNumber`` (an ``int``,
	or a ``float`` when it contains ``.``) or the value returned by
	``getIndepend``.  Then:

	* ``a_b``             -> real symbol named ``a_b``
	* ``a^b``             -> ``a**b``; if the exponent cannot be parsed, a
	  real symbol named ``a^b`` instead
	* ``a^b_c`` / ``a_c^b`` -> ``(a_c)**b``

	Args:
		E: remaining LaTeX text.
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(value, rest)``.
	"""
	dep=dep+1
	print(' '*dep*2,'>getContinuous',E)
	digits=''
	for ch in E:
		if ch not in basicNumber:
			break
		digits+=ch
	if digits!='':
		elem=digits
		try:
			res=int(elem)
		except ValueError:
			res=float(elem)	# e.g. "1.5"; malformed "1.2.3" still raises
		E=E[len(elem):]
	else:
		res,E=getIndepend(E,dep)
		elem=str(res)
	if E!='' and E[0]=='^':
		supStr,E=getIndependStr(E[1:],dep)
		if E!='' and E[0]=='_':
			subStr,E=getIndependStr(E[1:],dep)
			elem=elem+'_'+subStr
			res=symbols(elem,real=True)
		try:
			exponent,_unused=getIndepend(supStr,dep)
			# sympify keeps numeric powers exact, e.g. 2^{-1} -> 1/2 not 0.5
			res=sympify(res)**exponent
		except Exception:
			# Keep an unparseable power as a single symbol.  (The old code
			# built the name from the numeric prefix only, so x^... lost 'x'.)
			res=symbols(elem+'^'+supStr,real=True)
	elif E!='' and E[0]=='_':
		subStr,E=getIndependStr(E[1:],dep)
		elem=elem+'_'+subStr
		res=symbols(elem,real=True)
		if E!='' and E[0]=='^':
			afterCaret=E[1:]
			try:
				exponent,E=getIndepend(afterCaret,dep)
				res=res**exponent
			except Exception:
				# Re-read from the saved position: E may already have advanced.
				# Build the name from strings (the old code did Symbol + str).
				supStr,E=getIndependStr(afterCaret,dep)
				res=symbols(elem+'^'+supStr,real=True)
	print(' '*dep*2,'<getContinuous',res,E)
	return res,E
def getElement(E,dep):
	r"""Parse a chain of factors joined by strongly binding operators.

	Recognised operators are ``*``, ``\cdot``, ``\times`` and ``\ast``
	(multiply) and ``/`` and ``\div`` (divide).  Two adjacent factors with no
	operator between them are multiplied (``2x`` -> ``2*x``).  Parsing stops
	at ``+``, ``-``, a closing bracket, or the end of the text.

	Args:
		E: remaining LaTeX text.
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(value, rest)``.
	"""
	dep=dep+1
	print(' '*dep*2,'>getElement',E)
	mulOps=('*','\\cdot','\\times','\\ast')
	divOps=('/','\\div')
	res,E=getContinuous(E,dep)
	while E!='':
		op=next((tok for tok in mulOps+divOps if E.startswith(tok)),None)
		if op is not None:
			E=E[len(op):]
		elif E[0] in '+-})]':
			break
		# op is None here means implicit multiplication of adjacent factors
		elem,E=getContinuous(E,dep)
		if op in divOps:
			# sympify keeps int/int exact (1/2 -> Rational) and turns n/0 into
			# zoo.  The old code caught ZeroDivisionError, printed it and kept
			# the dividend, so "1/0" silently evaluated to 1.
			res=sympify(res)/elem
		else:
			res=res*elem
	print(' '*dep*2,'<getElement',res,E)
	return res,E
def getAll(E,dep):
	r"""Parse a signed sum of elements, then consume one closing character.

	Accepts an optional leading ``+``/``-``, followed by any number of
	``+``/``-`` separated elements (see ``getElement``).  A group that starts
	with a closing bracket is empty and yields ``''``.  After the sum, one
	more character is dropped: normally the bracket closing the group that
	``getUnit`` opened.

	Args:
		E: remaining LaTeX text (must be non-empty).
		dep: recursion depth; only used to indent the trace output.

	Returns:
		``(value, rest)``.
	"""
	dep+=1
	print(' '*dep*2,'>getAll',E)
	lead=E[0]
	if lead in '})]':
		res=''
	elif lead in '+-':
		res,E=getElement(E[1:],dep)
		if lead=='-':
			res=-res
	else:
		res,E=getElement(E,dep)
	while E and E[0] in '+-':
		sign=E[0]
		elem,E=getElement(E[1:],dep)
		res=res+elem if sign=='+' else res-elem
	E=E[1:]
	print(' '*dep*2,'<getAll',res,E)
	return res,E
def Trans(E):
	r"""Translate a LaTeX math expression into a simplified SymPy expression.

	Non-string input is converted with ``str()`` first.  Spaces are removed,
	``/_`` is rewritten to ``/``, and the text is wrapped in ``{}`` so that
	the whole input parses as one group.

	Args:
		E: the LaTeX expression.

	Returns:
		The simplified SymPy expression; ``E`` itself when it is ``''``; or
		the string ``'error'`` when parsing or simplification fails (the
		exception is printed).
	"""
	if E=='':
		return E
	print('Trans',E,type(E))
	E='{'+str(E)+'}'
	E=E.replace(' ','')
	E=E.replace('/_','/')
	try:
		res,rest=getUnit(E,0)
		if rest!='':
			# Unbalanced closing brackets leave text behind; report it rather
			# than silently returning a result for only part of the input.
			raise ValueError('unparsed trailing input: '+rest)
		print('trans got',res)
		return cancel(ratsimp(expand_log(simplify(res))))
	except Exception as e:
		# Exception, not BaseException: let KeyboardInterrupt/SystemExit through
		print(e)
		return 'error'
def Trans_equal(E):
	"""Translate an equation ``lhs=rhs`` as the expression ``(lhs)-(rhs)``.

	Only the first ``=`` splits the input.  Input without ``=`` is traced
	and passed to ``Trans`` unchanged.

	Returns:
		Whatever ``Trans`` returns.
	"""
	for pos,ch in enumerate(E):
		if ch!='=':
			continue
		lhs,rhs=E[:pos],E[pos+1:]
		return Trans('('+lhs+')-('+rhs+')')
	print("Trans_equal",E)
	return Trans(E)
if __name__=='__main__':
	# Manual smoke test: translate a sample expression and print the result.
	# (Replace the sample with input('') to try expressions interactively.)
	sample='\\lnx^3'
	print(Trans(sample))