Fastest way to find the smallest positive real root of a quartic (degree-4) polynomial in Python

2024/9/25 21:21:59

[What I want] I want to find the smallest positive real root of the quartic function ax^4 + bx^3 + cx^2 + dx + e.

[Existing Method] My equation is used for collision prediction. Its maximum degree is four, f(x) = ax^4 + bx^3 + cx^2 + dx + e, and the coefficients a, b, c, d, e can be positive, negative, or zero (real float values). So f(x) can be quartic, cubic, or quadratic depending on the input coefficients.

Currently, I use NumPy to find roots as below.

    import numpy

    root_output = numpy.roots([a, b, c, d, e])

The "root_output" array from NumPy can contain any mix of real and complex roots, depending on the input coefficients. So I have to go through "root_output" one by one and check which root is the smallest positive real value (root > 0).
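The filtering step described above can be sketched as follows. The function name and the tolerance `imag_tol` are illustrative assumptions, not part of the original code; a root is treated as real when its imaginary part is below that tolerance.

```python
import numpy as np

def smallest_positive_real_root(coeffs, imag_tol=1e-9):
    """Return the smallest positive real root of the polynomial, or None.

    imag_tol is an assumed threshold for treating a root as real.
    """
    roots = np.roots(coeffs)
    # Keep only roots whose imaginary part is negligible.
    real_roots = roots.real[np.abs(roots.imag) < imag_tol]
    positive = real_roots[real_roots > 0]
    return positive.min() if positive.size else None

# (x - 1)(x - 2)(x + 3) = x^3 - 7x + 6
print(smallest_positive_real_root([1.0, 0.0, -7.0, 6.0]))
```

Because numpy.roots strips leading zero coefficients, the same call also covers the cubic and quadratic cases.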

[The Problem] My program needs to execute numpy.roots([a, b, c, d, e]) many times, and that many calls are too slow for my project. The values of (a, b, c, d, e) change on every call.
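If many coefficient sets are available at once, one possible way to cut the per-call overhead is to solve them in a single batch. The sketch below is an assumption about how that could look, not something from the original post: it builds one companion matrix per quartic (the same idea numpy.roots uses internally) and passes the whole stack to numpy.linalg.eigvals. It assumes a != 0 in every row, and the function name and `imag_tol` are made up for illustration.

```python
import numpy as np

def batched_smallest_positive_roots(coeffs, imag_tol=1e-9):
    """coeffs: array of shape (n, 5), rows [a, b, c, d, e] with a != 0.

    Returns an array of shape (n,); NaN means no positive real root was found.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    monic = coeffs[:, 1:] / coeffs[:, :1]      # b/a, c/a, d/a, e/a
    companion = np.zeros((len(coeffs), 4, 4))
    companion[:, 0, :] = -monic                # first row of each companion matrix
    companion[:, 1, 0] = companion[:, 2, 1] = companion[:, 3, 2] = 1.0
    roots = np.linalg.eigvals(companion)       # eigenvalues = polynomial roots, shape (n, 4)
    ok = (np.abs(roots.imag) < imag_tol) & (roots.real > 0)
    best = np.where(ok, roots.real, np.inf).min(axis=1)
    return np.where(np.isinf(best), np.nan, best)

batch = [[1.0, -10.0, 35.0, -50.0, 24.0],      # roots 1, 2, 3, 4
         [1.0, 0.0, 0.0, 0.0, 1.0]]            # x^4 + 1, no real roots
print(batched_smallest_positive_roots(batch))
```

Whether this is actually faster on a Raspberry Pi 2 would need to be measured; it only helps when the coefficient sets do not depend on each other's results.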

I am trying to run the code on a Raspberry Pi 2. Example processing times are below.

  • Many calls to numpy.roots on a PC: 1.2 seconds
  • The same calls to numpy.roots on the Raspberry Pi 2: 17 seconds

Could you please guide me on the fastest way to find the smallest positive real root? Using scipy.optimize, implementing a root-finding algorithm, or any other advice would be great.

Thank you.

[Solution]

  • Quadratic function: only the positive real roots are needed (beware of division by zero when a = 0).

    import math

    def SolvQuadratic(a, b, c):
        # Returns the positive real roots in ascending order; assumes a != 0.
        d = (b**2) - (4*a*c)
        if d < 0:
            return []
        if d > 0:
            square_root_d = math.sqrt(d)
            t1 = (-b + square_root_d) / (2 * a)
            t2 = (-b - square_root_d) / (2 * a)
            if t1 > 0:
                if t2 > 0:
                    if t1 < t2:
                        return [t1, t2]
                    return [t2, t1]
                return [t1]
            elif t2 > 0:
                return [t2]
            else:
                return []
        else:
            # Double root.
            t = -b / (2*a)
            if t > 0:
                return [t]
            return []
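The division-by-zero caveat could be handled along the following lines. This is a minimal sketch, not the poster's code: the function name and the absolute threshold `eps` are assumptions, and a threshold that suits your coefficient scale may differ.

```python
import math

def positive_roots_quadratic_safe(a, b, c, eps=1e-12):
    """Positive real roots of a*x^2 + b*x + c = 0, sorted ascending.

    eps is an assumed absolute threshold for treating a coefficient as zero.
    """
    if abs(a) < eps:
        # Degenerate case: the equation is linear, b*x + c = 0.
        if abs(b) < eps:
            return []
        t = -c / b
        return [t] if t > 0 else []
    disc = b * b - 4 * a * c
    if disc < 0:
        return []
    s = math.sqrt(disc)
    # A set removes the duplicate when the discriminant is zero.
    roots = {(-b - s) / (2 * a), (-b + s) / (2 * a)}
    return sorted(t for t in roots if t > 0)

print(positive_roots_quadratic_safe(0.0, 2.0, -4.0))
print(positive_roots_quadratic_safe(1.0, -3.0, 2.0))
```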
  • Quartic function: you can use the pure Python/Numba version from @B.M.'s answer below. I also added a Cython version of @B.M.'s code. Save the code below as a .pyx file and compile it; it runs about 2x faster than pure Python (beware of rounding issues).
    import cmath

    cdef extern from "complex.h":
        double complex cexp(double complex)

    cdef double complex  J=cexp(2j*cmath.pi/3)
    cdef double complex  Jc=1/J

    cdef Cardano(double a, double b, double c, double d):
        cdef double z0
        cdef double a2, b2
        cdef double p ,q, D
        cdef double complex r
        cdef double complex u, v, w
        cdef double w0, w1, w2
        cdef double complex r1, r2, r3

        z0=b/3/a
        a2,b2 = a*a,b*b
        p=-b2/3/a2 +c/a
        q=(b/27*(2*b2/a2-9*c/a)+d)/a
        D=-4*p*p*p-27*q*q
        r=cmath.sqrt(-D/27+0j)
        u=((-q-r)/2)**0.33333333333333333333333
        v=((-q+r)/2)**0.33333333333333333333333
        w=u*v
        w0=abs(w+p/3)
        w1=abs(w*J+p/3)
        w2=abs(w*Jc+p/3)
        if w0<w1:
            if w2<w0 : v = v*Jc
        elif w2<w1 : v = v*Jc
        else: v = v*J
        r1 = u+v-z0
        r2 = u*J+v*Jc-z0
        r3 = u*Jc+v*J-z0
        return r1, r2, r3

    cdef Roots_2(double a, double complex b, double complex c):
        cdef double complex bp
        cdef double complex delta
        cdef double complex r1, r2

        bp=b/2
        delta=bp*bp-a*c
        r1=(-bp-delta**.5)/a
        r2=-r1-b/a
        return r1, r2

    def SolveQuartic(double a, double b, double c, double d, double e):
        "Ferrari's Method"
        "resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
        "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
        cdef double z0
        cdef double a2, b2, c2, d2
        cdef double p, q, r
        cdef double A, B, C, D
        cdef double complex y0, y1, y2
        cdef double complex a0, b0
        cdef double complex r0, r1, r2, r3

        z0=b/4.0/a
        a2,b2,c2,d2 = a*a,b*b,c*c,d*d
        p = -3.0*b2/(8*a2)+c/a
        q = b*b2/8.0/a/a2 - 1.0/2*b*c/a2 + d/a
        r = -3.0/256*b2*b2/a2/a2 + c*b2/a2/a/16 - b*d/a2/4+e/a
        "Second find y so P2=Ay^3+By^2+Cy+D=0"
        A=8.0
        B=-4*p
        C=-8*r
        D=4*r*p-q*q
        y0,y1,y2=Cardano(A,B,C,D)
        if abs(y1.imag)<abs(y0.imag): y0=y1
        if abs(y2.imag)<abs(y0.imag): y0=y2
        a0=(-p+2*y0)**.5
        if a0==0 : b0=y0**2-r
        else : b0=-q/2/a0
        r0,r1=Roots_2(1,a0,y0+b0)
        r2,r3=Roots_2(1,-a0,y0-b0)
        return (r0-z0,r1-z0,r2-z0,r3-z0)

[Problem with Ferrari's method] We run into a problem when the coefficients of the quartic equation are [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869]: the outputs of numpy.roots and of Ferrari's method are entirely different (the numpy.roots output is the correct one).

    import numpy as np
    import cmath

    J=cmath.exp(2j*cmath.pi/3)
    Jc=1/J

    def ferrari(a,b,c,d,e):
        "Ferrari's Method"
        "resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
        "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
        z0=b/4/a
        a2,b2,c2,d2 = a*a,b*b,c*c,d*d
        p = -3*b2/(8*a2)+c/a
        q = b*b2/8/a/a2 - 1/2*b*c/a2 + d/a
        r = -3/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
        "Second find y so P2=Ay^3+By^2+Cy+D=0"
        A=8
        B=-4*p
        C=-8*r
        D=4*r*p-q*q
        y0,y1,y2=Cardano(A,B,C,D)
        if abs(y1.imag)<abs(y0.imag): y0=y1
        if abs(y2.imag)<abs(y0.imag): y0=y2
        a0=(-p+2*y0)**.5
        if a0==0 : b0=y0**2-r
        else : b0=-q/2/a0
        r0,r1=Roots_2(1,a0,y0+b0)
        r2,r3=Roots_2(1,-a0,y0-b0)
        return (r0-z0,r1-z0,r2-z0,r3-z0)

    #~ @jit(nopython=True)
    def Cardano(a,b,c,d):
        z0=b/3/a
        a2,b2 = a*a,b*b
        p=-b2/3/a2 +c/a
        q=(b/27*(2*b2/a2-9*c/a)+d)/a
        D=-4*p*p*p-27*q*q
        r=cmath.sqrt(-D/27+0j)
        u=((-q-r)/2)**0.33333333333333333333333
        v=((-q+r)/2)**0.33333333333333333333333
        w=u*v
        w0=abs(w+p/3)
        w1=abs(w*J+p/3)
        w2=abs(w*Jc+p/3)
        if w0<w1:
            if w2<w0 : v*=Jc
        elif w2<w1 : v*=Jc
        else: v*=J
        return u+v-z0, u*J+v*Jc-z0, u*Jc+v*J-z0

    #~ @jit(nopython=True)
    def Roots_2(a,b,c):
        bp=b/2
        delta=bp*bp-a*c
        r1=(-bp-delta**.5)/a
        r2=-r1-b/a
        return r1,r2

    coef = [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869]
    print("Coefficient A, B, C, D, E", coef)
    print("")
    print("numpy roots: ", np.roots(coef))
    print("")
    print("ferrari python ", ferrari(*coef))
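One way to guard against cases like this is to check each root from the fast solver against the original polynomial and fall back to numpy.roots when the check fails. The sketch below is only an illustration under assumptions: `roots_with_fallback`, `fast_solver`, and `rel_tol` are hypothetical names and values, and the residual test is a heuristic rather than a guarantee.

```python
import numpy as np

def roots_with_fallback(coeffs, fast_solver, rel_tol=1e-6):
    """Return fast_solver's roots if they pass a residual test, else numpy.roots.

    fast_solver is any callable taking (a, b, c, d, e) and returning four roots.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    try:
        roots = np.asarray(fast_solver(*coeffs.tolist()), dtype=complex)
    except ZeroDivisionError:
        return np.roots(coeffs)
    # |P(x)| at each candidate root.
    residual = np.abs(np.polyval(coeffs, roots))
    # Magnitude of the terms being summed, sum |c_k| * |x|^k, used as a scale.
    scale = np.polyval(np.abs(coeffs), np.abs(roots))
    if np.all(np.isfinite(roots)) and np.all(residual <= rel_tol * scale):
        return roots
    return np.roots(coeffs)

coef = [1.0, -10.0, 35.0, -50.0, 24.0]          # roots 1, 2, 3, 4
good = lambda a, b, c, d, e: (1.0, 2.0, 3.0, 4.0)
bad = lambda a, b, c, d, e: (5.0, 6.0, 7.0, 8.0)
print(roots_with_fallback(coef, good))
print(roots_with_fallback(coef, bad))
```

The check costs a few polynomial evaluations per call, so the slow path is only taken when the analytic result looks wrong.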
Answer

Another answer:

Do it with analytic methods (Ferrari, Cardano), and speed up the code with just-in-time compilation (Numba):

Let's see the improvement first:

    In [2]: P=poly1d([1,2,3,4],True)

    In [3]: roots(P)
    Out[3]: array([ 4.,  3.,  2.,  1.])

    In [4]: %timeit roots(P)
    1000 loops, best of 3: 465 µs per loop

    In [5]: ferrari(*P.coeffs)
    Out[5]: ((1+0j), (2-0j), (3+0j), (4-0j))

    In [5]: %timeit ferrari(*P.coeffs) #pure python without jit
    10000 loops, best of 3: 116 µs per loop

    In [6]: %timeit ferrari(*P.coeffs)  # with numba.jit
    100000 loops, best of 3: 13 µs per loop

Then the ugly code:

For order 4:

    # Assumed import; the original snippet did not show it.
    from numba import jit

    @jit(nopython=True)
    def ferrari(a,b,c,d,e):
        "resolution of P=ax^4+bx^3+cx^2+dx+e=0"
        "Condition: all coeffs real."
        "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
        z0=b/4/a
        a2,b2,c2,d2 = a*a,b*b,c*c,d*d
        p = -3*b2/(8*a2)+c/a
        q = b*b2/8/a/a2 - 1/2*b*c/a2 + d/a
        r = -3/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
        "Second find X so P2=AX^3+BX^2+CX+D=0"
        A=8
        B=-4*p
        C=-8*r
        D=4*r*p-q*q
        y0,y1,y2=cardan(A,B,C,D)
        if abs(y1.imag)<abs(y0.imag): y0=y1
        if abs(y2.imag)<abs(y0.imag): y0=y2
        a0=(-p+2*y0.real)**.5
        if a0==0 : b0=y0**2-r
        else : b0=-q/2/a0
        r0,r1=roots2(1,a0,y0+b0)
        r2,r3=roots2(1,-a0,y0-b0)
        return (r0-z0,r1-z0,r2-z0,r3-z0)

For order 3:

    # Assumed imports; the original relied on a pylab-style namespace.
    from numpy import exp, pi, sqrt, empty, complex128

    J=exp(2j*pi/3)
    Jc=1/J

    @jit(nopython=True)
    def cardan(a,b,c,d):
        u=empty(2,complex128)
        z0=b/3/a
        a2,b2 = a*a,b*b
        p=-b2/3/a2 +c/a
        q=(b/27*(2*b2/a2-9*c/a)+d)/a
        D=-4*p*p*p-27*q*q
        r=sqrt(-D/27+0j)
        u=((-q-r)/2)**0.33333333333333333333333
        v=((-q+r)/2)**0.33333333333333333333333
        w=u*v
        w0=abs(w+p/3)
        w1=abs(w*J+p/3)
        w2=abs(w*Jc+p/3)
        if w0<w1:
            if w2<w0 : v*=Jc
        elif w2<w1 : v*=Jc
        else: v*=J
        return u+v-z0, u*J+v*Jc-z0,u*Jc+v*J-z0

For order 2:

    @jit(nopython=True)
    def roots2(a,b,c):
        bp=b/2
        delta=bp*bp-a*c
        u1=(-bp-delta**.5)/a
        u2=-u1-b/a
        return u1,u2

This probably needs further testing, but it is efficient.
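A simple way to do that further testing is to compare the analytic solver with numpy.roots on random coefficients. The harness below is a hedged sketch: `count_mismatches`, the coefficient range, and the tolerance are all assumptions chosen for illustration, and `solver` stands for any function with the same signature as ferrari above.

```python
import numpy as np

def count_mismatches(solver, trials=1000, tol=1e-6, seed=0):
    """Count random quartics where solver(a, b, c, d, e) disagrees with numpy.roots."""
    rng = np.random.default_rng(seed)
    mismatches = 0
    for _ in range(trials):
        coeffs = rng.uniform(-10.0, 10.0, size=5)
        expected = np.roots(coeffs)
        got = np.asarray(solver(*coeffs), dtype=complex)
        if got.size != expected.size:
            mismatches += 1
            continue
        # Distance from each reference root to the nearest root returned by solver.
        dist = np.abs(expected[:, None] - got[None, :]).min(axis=1)
        # Written with "not all(<=)" so that NaN results count as mismatches.
        if not np.all(dist <= tol * (1.0 + np.abs(expected))):
            mismatches += 1
    return mismatches

# Sanity check of the harness itself, using numpy.roots as the "solver".
print(count_mismatches(lambda *c: np.roots(c), trials=100))
```

Printing the failing coefficient sets instead of only counting them makes it easier to find inputs like the one reported in the question.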

https://en.xdnf.cn/q/71493.html
