#These functions numerically solve linear and almost linear autonomous systems
#and generate and plot trajectories in the neighborhood of a critical point.
#This script is provided AS IS and for educational purposes only.
#It is not guaranteed for merchantability or for any particular purpose.
#Copyright (C) by Eva A. Horvath

import matplotlib.pyplot as plt
from math import sqrt

def menu_calctypes():
    """Print the numbered menu of example systems offered by choose_fg().

    The examples come in pairs: an odd-numbered linear system followed by an
    even-numbered almost linear system with the same linear part, so the two
    portraits near the critical point at the origin can be compared.  A blank
    line follows each pair.
    """
    pairs = [
        ("1 Linear proper node dx/dt = -2x and dy/dt = -2y",
         # y = y0*e^(-2t) and x ~ x0*e^(y0/2)*e^(-2t), so trajectories still
         # reach the origin from every direction: a proper node.
         "2 Nonlinear proper node dx/dt = -2x + xy and dy/dt = -2y"),
        ("3 Linear improper node dx/dt = -x and dy/dt = -3y",
         "4 Nonlinear improper node dx/dt = -x and dy/dt = -3y + y^3"),
        ("5 Linear saddle point dx/dt = -3x and dy/dt = y",
         "6 Nonlinear saddle point dx/dt = -3x + y^2 and dy/dt = y"),
        ("7 Linear improper node dx/dt = -2x and dy/dt = x - 2y",
         "8 Nonlinear node dx/dt = -2x + y^2 and dy/dt = x - 2y"),
        ("9 Linear spiral point dx/dt = -x + 2y and dy/dt = -2x - y",
         "10 Nonlinear spiral point dx/dt = -x + 2y and dy/dt = -2x - y + xy"),
        ("11 Linear center dx/dt = y and dy/dt = -x",
         # The system is unchanged under (x, t) -> (-x, -t), so the linear
         # center persists in the nonlinear system.
         "12 Nonlinear center dx/dt = y + x^2 and dy/dt = -x"),
    ]
    for linear, nonlinear in pairs:
        print(linear)
        print(nonlinear)
        print()
    return

def choose_fg(choice):
    """Look up the right-hand sides of example system *choice*.

    Args:
        choice: menu number as a string, "1" to "12".

    Returns:
        A tuple (f, g) of functions of (x, y) such that dx/dt = f(x, y) and
        dy/dt = g(x, y), or None if *choice* matches no example.
    """
    # Odd numbers are linear systems; each even number adds nonlinear terms
    # to the linear system numbered just before it.
    examples = (
        ("1", lambda x, y: -2.0*x, lambda x, y: -2.0*y),
        ("2", lambda x, y: -2.0*x + x*y, lambda x, y: -2.0*y),
        ("3", lambda x, y: -x, lambda x, y: -3.0*y),
        ("4", lambda x, y: -x, lambda x, y: -3.0*y + y*y*y),
        ("5", lambda x, y: -3.0*x, lambda x, y: y),
        ("6", lambda x, y: -3.0*x + y*y, lambda x, y: y),
        ("7", lambda x, y: -2.0*x, lambda x, y: x - 2.0*y),
        ("8", lambda x, y: -2.0*x + y*y, lambda x, y: x - 2.0*y),
        ("9", lambda x, y: -x + 2.0*y, lambda x, y: -2.0*x - y),
        ("10", lambda x, y: -x + 2.0*y, lambda x, y: -2.0*x - y + x*y),
        ("11", lambda x, y: y, lambda x, y: -x),
        ("12", lambda x, y: y + x*x, lambda x, y: -x),
    )
    for key, rhs_x, rhs_y in examples:
        if choice == key:
            return rhs_x, rhs_y
    return None

def enter_startingvalues():
    """Ask the user for starting points and the maximum integration time.

    "x0, y0" pairs are read until '0,0' is entered, then t max is read.
    The dialog stops early on malformed input: a pair that is not two
    comma-separated numbers, a non-numeric or non-finite t max, or end of
    input.  Anything not yet supplied then takes its default: the single
    starting point [1.0, 0.5] and t max 1.0.  Pairs entered before the bad
    input are kept rather than discarded.

    Returns:
        (a, tm): a list of [x0, y0] starting points and t max.
    """
    from math import isfinite

    prompt = "Enter x0, y0.  '0,0' to exit: "
    a = []
    tm = 1.0
    try:
        while True:
            pl = input(prompt).split(",")
            pl0 = float(pl[0])
            pl1 = float(pl[1])
            if pl0 == 0 and pl1 == 0:
                break
            a.append([pl0, pl1])
        t_entered = float(input("t max: "))
        # A non-finite t max cannot bound the integration, so keep the default.
        if isfinite(t_entered):
            tm = t_entered
    except (ValueError, IndexError, EOFError):
        # Bad or missing input ends the dialog: keep what was entered so far
        # and use the defaults for the rest.
        pass
    if not a:
        # With no point entered (e.g. '0,0' first) there would be nothing to plot.
        a = [[1.0, 0.5]]
    return a, tm

def choose_startingvalues(choice, data_choice):
    """Return the starting points and t max for example *choice*.

    Args:
        choice: menu number of the example as a string, "1" to "12".
        data_choice: the answer to "Use stored starting values?".  "y" or
            "yes" (any letter case, surrounding whitespace ignored) selects
            the stored values below.  Any other answer asks the user for
            them via enter_startingvalues().

    Returns:
        (a, tmax): a freshly built list of [x0, y0] starting points and the
        maximum integration time.

    Raises:
        ValueError: stored values were requested for a choice that has none.
    """
    if data_choice.strip().lower() not in ("y", "yes"):
        return enter_startingvalues()

    if choice in ("1", "2", "3", "4", "7", "8"):
        # Eight points, 45 degrees apart, on a circle of radius a0 about the
        # origin.
        a0 = 0.1
        d = a0/sqrt(2.0)
        a = [[a0, 0.0], [d, d], [0, a0], [-d, d],
             [-a0, 0.0], [d, -d], [0.0, -a0], [-d, -d]]
        tmax = 5.0
    elif choice in ("5", "6"):
        # Points on the x-axis, just off it, and just off the origin along
        # the y-axis.
        a = [[1.0, 0.0], [0.0, 0.01], [-1.0, 0.0], [0.0, -0.01],
             [-1.0, 0.01], [-1.0, -0.01], [1.0, 0.01], [1.0, -0.01]]
        tmax = 2.0
    elif choice in ("9", "10"):
        # Four points on the unit circle.
        a = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
        tmax = 5.0
    elif choice == "11":
        a = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]
        tmax = 8.0
    elif choice == "12":
        a = [[0.2, 0.2]]
        tmax = 8.0
    else:
        raise ValueError(f"no stored starting values for choice {choice!r}")
    return a, tmax

def rk(f, g, a0, t0=0.0, h=0.001, tmax=5.0):
    """Integrate dx/dt = f(x, y), dy/dt = g(x, y) with the classical RK4 method.

    Starts from (x, y) = (a0[0], a0[1]) at time t0 and takes fixed steps of
    size h until t reaches tmax.  No step is taken if tmax <= t0.

    Args:
        f, g: right-hand sides, each called as f(x, y) / g(x, y).
        a0: starting point; only a0[0] and a0[1] are read.
        t0: starting time.
        h: step size; must be positive.
        tmax: final time.

    Returns:
        (x, y): lists of trajectory coordinates.  Each list starts with the
        initial point and has one more entry than the number of steps.

    Raises:
        ValueError: if h is not positive or t0/tmax is not finite (either
        would otherwise make the loop run forever or misbehave).
    """
    from math import ceil, isfinite

    if not h > 0:
        raise ValueError("step size h must be positive")
    if not (isfinite(t0) and isfinite(tmax)):
        raise ValueError("t0 and tmax must be finite")
    # Compute the step count once instead of accumulating tn += h.  Repeated
    # addition drifts (ten steps of 0.1 sum to 0.9999999999999999), which
    # used to add an extra step past tmax.  The tiny tolerance absorbs
    # rounding in the division itself.
    n_steps = max(0, ceil((tmax - t0) / h - 1e-9))

    xn = a0[0]
    yn = a0[1]
    x = [xn]
    y = [yn]
    half = h / 2.0
    for _ in range(n_steps):
        kn1 = f(xn, yn)
        ln1 = g(xn, yn)
        kn2 = f(xn + half*kn1, yn + half*ln1)
        ln2 = g(xn + half*kn1, yn + half*ln1)
        kn3 = f(xn + half*kn2, yn + half*ln2)
        ln3 = g(xn + half*kn2, yn + half*ln2)
        kn4 = f(xn + h*kn3, yn + h*ln3)
        ln4 = g(xn + h*kn3, yn + h*ln3)
        xn = xn + h/6.0 * (kn1 + 2.0*kn2 + 2.0*kn3 + kn4)
        yn = yn + h/6.0 * (ln1 + 2.0*ln2 + 2.0*ln3 + ln4)
        x.append(xn)
        y.append(yn)
    return x, y

def genplot():
    """Run the interactive demo: choose an example, choose starting points, plot.

    Shows the menu, looks up f and g for the chosen example, and gets the
    starting points and t max, either stored or typed in.  Each trajectory
    is integrated with rk() and scatter-plotted in the x-y plane.  Returns
    without plotting if the menu choice is not a known example or there are
    no starting points.
    """
    menu_calctypes()
    choice = input("Enter your choice.  Any other key to exit: ")
    func = choose_fg(choice)
    if func is None:
        return
    f, g = func
    data_choice = input("Use stored starting values? ")
    a, tm = choose_startingvalues(choice, data_choice)
    if not a:
        print("No starting values entered; nothing to plot.")
        return
    ax = plt.subplot()
    colors = ["red", "green", "blue", "purple", "yellow", "orange", "brown", "cyan"]
    for i, start in enumerate(a):
        xp, yp = rk(f, g, start, tmax=tm)
        # Marker size 1; colors repeat when there are more trajectories than colors.
        ax.scatter(xp, yp, 1, c=colors[i % len(colors)])
    ax.set(xlabel='x', ylabel='y', title='Critical point neighborhood.')
    plt.show()
    return

if __name__ == "__main__":
    # Start the interactive demo only when run as a script, not when imported.
    genplot()