#These functions are used to generate and plot the trajectories of
#linear autonomous systems in order to illustrate the behavior in the
#neighborhood of a critical point.  This script is provided AS IS and for
#educational purposes only.  It is not guaranteed for merchantability or for
#any particular purpose.
#Copyright (C) 2020 by Eva A. Horvath

import matplotlib.pyplot as plt
from math import exp, sqrt, cos, sin

def menu_calctypes():
    """Print the numbered menu of available linear systems, then a blank line.

    Each entry names the type of critical point and the pair of differential
    equations that the matching trajectory function solves.  Returns None.
    """
    entries = (
        "1 Linear proper node dx/dt = -2x and dy/dt = -2y",
        "2 Linear improper node dx/dt = -x and dy/dt = -3y",
        "3 Linear saddle point dx/dt = -3x and dy/dt = y",
        "4 Linear improper node dx/dt = -2x and dy/dt = x - 2y",
        "5 Linear spiral point dx/dt = -x + 2y and dy/dt = -2x - y",
        "6 Linear center dx/dt = y and dy/dt = -x",
    )
    for entry in entries:
        print(entry)
    # Blank separator line between the menu and whatever is printed next.
    print()

def proper_node(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = -2x, dy/dt = -2y (proper node at (0, 0)).

    Uses the closed-form solution x = a0[0]*exp(-2t), y = a0[1]*exp(-2t), so
    a0 is the point the solution passes through at t = 0 (not at t0).

    Args:
        a0: indexable pair; a0[0] and a0[1] scale the x and y components.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    # Build the time grid by repeated addition of h, which fixes the exact
    # floating-point sample instants.
    times = []
    now = t0
    while now < tmax:
        times.append(now)
        now += h
    # Both coordinates share the same decay factor at every instant.
    decay = [exp(-2.0*s) for s in times]
    return [a0[0]*d for d in decay], [a0[1]*d for d in decay]

def improper_node(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = -x, dy/dt = -3y (improper node at (0, 0)).

    Uses the closed-form solution x = a0[0]*exp(-t), y = a0[1]*exp(-3t), so
    a0 is the point the solution passes through at t = 0 (not at t0).

    Args:
        a0: indexable pair; a0[0] and a0[1] scale the x and y components.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    # Sample instants are produced by repeated addition of h starting at t0.
    times = []
    now = t0
    while now < tmax:
        times.append(now)
        now += h
    # Evaluate each sample as an (x, y) point, then split into two lists.
    points = [(a0[0]*exp(-s), a0[1]*exp(-3.0*s)) for s in times]
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    return x, y

def saddle_point(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = -3x, dy/dt = y (saddle point at (0, 0)).

    Uses the closed-form solution x = a0[0]*exp(-3t), y = a0[1]*exp(t), so
    a0 is the point the solution passes through at t = 0 (not at t0).  The x
    component decays while the y component grows.

    Args:
        a0: indexable pair; a0[0] and a0[1] scale the x and y components.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    def sample_times():
        # Repeated addition of h fixes the exact floating-point instants.
        now = t0
        while now < tmax:
            yield now
            now += h

    xs, ys = [], []
    for now in sample_times():
        xs.append(a0[0]*exp(-3.0*now))
        ys.append(a0[1]*exp(now))
    return xs, ys

def improper_node_a(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = -2x, dy/dt = x - 2y (improper node at (0, 0)).

    Uses the closed-form solution x = a0[0]*exp(-2t),
    y = a0[0]*t*exp(-2t) + a0[1]*exp(-2t), so a0 is the point the solution
    passes through at t = 0 (not at t0).

    Args:
        a0: indexable pair; a0[0] and a0[1] are the x and y values at t = 0.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    xs = []
    ys = []
    now = t0
    while True:
        if not now < tmax:
            break
        # exp(-2t) appears in every term; evaluate it once per sample.
        decay = exp(-2.0*now)
        xs.append(a0[0]*decay)
        # The t*exp(-2t) term comes from the coupling of x into dy/dt.
        ys.append(a0[0]*now*decay + a0[1]*decay)
        now += h
    return xs, ys

def spiral_point(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = -x + 2y, dy/dt = -2x - y (spiral point at (0, 0)).

    Uses the closed-form solution
    x = exp(-t)*(a0[0]*cos(2t) + a0[1]*sin(2t)),
    y = exp(-t)*(a0[1]*cos(2t) - a0[0]*sin(2t)),
    so a0 is the point the solution passes through at t = 0 (not at t0).

    Args:
        a0: indexable pair; a0[0] and a0[1] are the x and y values at t = 0.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    # Sample instants are produced by repeated addition of h starting at t0.
    times = []
    now = t0
    while now < tmax:
        times.append(now)
        now += h

    xs = []
    ys = []
    for s in times:
        damping = exp(-s)      # shrinking radius
        c = cos(2.0*s)         # rotation at angular rate 2
        r = sin(2.0*s)
        xs.append(damping*(a0[0]*c + a0[1]*r))
        ys.append(damping*(a0[1]*c - a0[0]*r))
    return xs, ys

def center(a0, t0=0.0, h=0.001, tmax=5.0):
    """Sample a trajectory of dx/dt = y, dy/dt = -x (center at (0, 0)).

    Uses the closed-form solution
    x = a0[0]*cos(t) + a0[1]*sin(t),
    y = a0[1]*cos(t) - a0[0]*sin(t),
    which passes through a0 at t = 0 (not at t0) and satisfies both equations
    for every a0.  The trajectory is a circle about the origin whose radius is
    the distance of a0 from the origin.

    Args:
        a0: indexable pair; a0[0] and a0[1] are the x and y values at t = 0.
        t0: first sample time.
        h: time step between samples.
        tmax: sampling stops once the time is no longer below tmax.

    Returns:
        A pair (x, y) of equally long lists of sampled coordinates.
    """
    t = t0
    x = []
    y = []
    while t < tmax:
        c = cos(t)
        s = sin(t)
        # Rotation of the initial point: x' = y and y' = -x hold term by term,
        # unlike x = a0[0]*sin(t), y = a0[1]*cos(t), which is only a solution
        # when a0[0] == a0[1] and never starts at a0.
        x.append(a0[0]*c + a0[1]*s)
        y.append(a0[1]*c - a0[0]*s)
        t += h
    return x, y

def choose_fun(choice):
    """Map a menu choice to starting points, a trajectory function and an end time.

    Args:
        choice: the menu entry as a string ("1" through "6").  Any value other
            than "1" through "5" selects the center system.

    Returns:
        A tuple (a, fun, tm): a is a list of [x, y] pairs passed one at a time
        to fun, fun is the trajectory generator, and tm is the tmax to use.
    """
    if choice == "3":  # Saddle point
        seeds = [[1.0, 0.0], [0.0, 0.01], [-1.0, 0.0], [0.0, -0.01],
                 [-1.0, 0.01], [-1.0, -0.01], [1.0, 0.01], [1.0, -0.01]]
        return seeds, saddle_point, 2.5

    if choice == "5":  # Spiral point
        seeds = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
        return seeds, spiral_point, 5.0

    if choice in ("1", "2", "4"):  # Proper and improper nodes
        # Eight points on a circle of this radius, 45 degrees apart.
        radius = 1.0
        diag = radius/sqrt(2.0)
        seeds = [
            [radius, 0.0], [diag, diag], [0, radius], [-diag, diag],
            [-radius, 0.0], [diag, -diag], [0.0, -radius], [-diag, -diag],
        ]
        if choice == "1":
            generator = proper_node
        else:
            generator = improper_node if choice == "2" else improper_node_a
        return seeds, generator, 5.0

    # Center - reached for every choice other than '1', '2', '3', '4', or '5'
    return [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], center, 8.0
  
def genplot():
    """Show the menu, read a choice, and scatter-plot the chosen trajectories.

    One trajectory is generated per starting point returned by choose_fun and
    drawn in its own color; colors repeat if there are more trajectories than
    palette entries.  Blocks until the plot window is closed.
    """
    menu_calctypes()
    selection = input("Enter your choice: ")
    starts, trajectory, t_end = choose_fun(selection)

    palette = ["red", "green", "blue", "purple", "yellow", "orange", "brown", "cyan"]
    axes = plt.subplot()
    for idx, start in enumerate(starts):
        xs, ys = trajectory(start, tmax=t_end)
        # Third positional argument is the marker size.
        axes.scatter(xs, ys, 1, c=palette[idx % len(palette)])

    axes.set(xlabel='x', ylabel='y', title='Neighborhood of Critical Point')
    plt.show()

# Run the interactive demo only when this file is executed as a script, so
# that importing the module does not prompt for input or open a plot window.
if __name__ == "__main__":
    genplot()