# Star-import order is significant: on a name clash the later import wins.
from pythreejs import *
from sympy import *
import math
try:
    # The container ABCs live in collections.abc (Python 3.3+); the
    # `collections.Iterable` alias was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # only reached on very old Pythons
    from collections import Iterable

# Public names exported by a star import of this module.
__all__ = [
    "Vector",
    "arrow",
    "mark",
    "draw",
    "vec_geom",
    "refresh",
]


# from: https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
def flatten(items):
    """Yield items from any nested iterable; see Reference."""
    for x in items:
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            for sub_x in flatten(x):
                yield sub_x
        else:
            yield x

####

def refresh(ds):
    """Call ``.refresh()`` on every object in *ds*, in iteration order."""
    for drawable in ds:
        drawable.refresh()

def Vector(*v):
    """Return a sympy column matrix whose rows are the given components.

    For example ``Vector(1, 2, 3)`` equals ``Matrix([[1], [2], [3]])``.
    """
    rows = [[component] for component in v]
    return Matrix(rows)

class arrow:
    """An arrow mesh from *start* to *start + v*, for a pythreejs scene.

    ``v`` and ``start`` are kept on the instance (as iterables of numbers
    convertible to float, e.g. results of ``Vector``). After their entries are
    changed, ``refresh()`` rebuilds the geometry from their current values.

    Attributes:
        v: arrow direction/length.
        start: tail position.
        w: shaft half-width passed to ``vec_geom``.
        geom: the current arrow geometry.
        mesh: the ``Mesh`` to add to a scene.
    """
    def __init__(self, v, start=None, color='blue', w=0.04):
        # Build a fresh zero vector per instance. A `Vector(0,0,0)` default
        # would be evaluated once at class-definition time and shared by every
        # arrow; since a sympy Matrix is mutable, editing one arrow's start in
        # place would silently move all the others.
        if start is None:
            start = Vector(0, 0, 0)
        self.v = v
        self.start = start
        self.w = w
        self.geom = self._build_geom()
        self.mesh = Mesh(geometry=self.geom, material=MeshLambertMaterial(color=color))

    def _build_geom(self):
        """Return a new arrow geometry from the current ``v``, ``start`` and ``w``."""
        return vec_geom([float(x) for x in self.v],
                        start=[float(x) for x in self.start],
                        w=self.w)

    def refresh(self):
        """Rebuild the geometry from the current ``v``/``start`` and swap it into the mesh."""
        self.geom = self._build_geom()
        self.mesh.geometry = self.geom

class mark:
    """A small sphere marking the point *v* in a pythreejs scene.

    Attributes:
        v: the marked position, an iterable of coordinates convertible to
           float (three for a 3D position).
        geom: the ``SphereGeometry`` used for the marker.
        mesh: the ``Mesh`` to add to a scene.
    """
    def __init__(self, v, color='black', w=0.08, radius=0.05):
        # NOTE(review): `w` is accepted but unused (kept for backward
        # compatibility); the sphere size is set by `radius`.
        self.v = v
        self.geom = SphereGeometry(radius=radius)
        self.mesh = Mesh(geometry=self.geom,
                         material=MeshLambertMaterial(color=color),
                         position=[float(x) for x in self.v])

    def refresh(self):
        """Move the marker to the current value of ``self.v``.

        Gives marks the same ``refresh()`` interface as arrows, so both can be
        updated together after their vectors change.
        """
        self.mesh.position = [float(x) for x in self.v]


def vec_geom(v, start, w=0.04):
    """Build a pythreejs ``Geometry`` for an arrow from *start* to *start + v*.

    The arrow is a square-section shaft whose four edges lie at distance *w*
    from the axis, capped by a pyramidal head of length ``6*w`` whose base
    corners lie at distance ``3*w`` from the axis.

    Args:
        v: 3 numbers, the arrow direction and length.
        start: 3 numbers, the tail position.
        w: shaft size (axis-to-edge distance of the shaft cross-section).

    Returns:
        An empty ``Geometry()`` if *v* has zero length. Otherwise a geometry
        with 16 vertices and 24 triangular faces, with face normals computed.
        If ``|v| < 6*w`` the head base lies behind the tail.
    """
    p1 = start
    p2 = [start[0] + v[0], start[1] + v[1], start[2] + v[2]]
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    dz = p2[2] - p1[2]
    d = math.sqrt(dx*dx + dy*dy + dz*dz)
    if d == 0:
        # Zero-length vector: nothing to draw.
        return Geometry()

    # Base of the arrowhead: 6*w back from the tip along the axis.
    p3 = [p2[0] - 6*dx*w/d, p2[1] - 6*dy*w/d, p2[2] - 6*dz*w/d]

    # Two length-w vectors, perpendicular to the axis and to each other, and
    # oriented so that o1 x o2 points along +v. The face table below relies
    # on that orientation for outward-facing (counter-clockwise) winding.
    if dx == 0 and dy == 0:
        # Axis parallel to z: the general formula below degenerates (o1 = 0).
        # Flip o2 for arrows pointing down -z so o1 x o2 still points along +v;
        # otherwise those arrows are built inside-out.
        o1 = [w, 0, 0]
        o2 = [0, w if dz > 0 else -w, 0]
    else:
        o1 = [dy, -dx, 0]
        o2 = [dx*dz, dy*dz, -dx*dx - dy*dy]
        n1 = math.sqrt(o1[0]*o1[0] + o1[1]*o1[1])
        n2 = math.sqrt(o2[0]*o2[0] + o2[1]*o2[1] + o2[2]*o2[2])
        o1 = [o1[0]/n1 * w, o1[1]/n1 * w, o1[2]/n1 * w]
        o2 = [o2[0]/n2 * w, o2[1]/n2 * w, o2[2]/n2 * w]

    def offset(p, o, k):
        """Return the point p + k*o."""
        return [p[0] + k*o[0], p[1] + k*o[1], p[2] + k*o[2]]

    # The four corners of a square cross-section in the order the face table
    # expects: -o1, +o2, -o2, +o1.
    corners = [(-1, o1), (1, o2), (-1, o2), (1, o1)]
    vertices = (
        [offset(p1, o, k) for k, o in corners]        # 0-3: shaft, tail end
        + [offset(p3, o, k) for k, o in corners]      # 4-7: shaft, head end
        + [offset(p3, o, 3*k) for k, o in corners]    # 8-11: head base
        + [p2] * 4                                    # 12-15: tip
    )

    # Triangles of a square prism between corners 0-3 and 4-7: both end caps
    # plus four side quads, each split into two triangles.
    prism = [
        [0, 1, 3], [0, 3, 2], [0, 2, 4], [2, 6, 4], [0, 4, 1], [1, 4, 5],
        [2, 3, 6], [3, 7, 6], [1, 5, 3], [3, 5, 7], [4, 6, 5], [5, 6, 7],
    ]
    # The same table shifted by 8 spans head base (8-11) -> tip (12-15). All
    # four tip vertices coincide, so that "prism" is a pyramid and its far
    # cap collapses to zero-area triangles.
    faces = prism + [[i + 8 for i in f] for f in prism]

    geometry = Geometry(vertices=vertices, faces=faces)
    # Calculate normals per face, for nice crisp edges.
    geometry.exec_three_obj_method('computeFaceNormals')
    return geometry

def slider(mesh, fv):
    """Show a slider that replaces ``mesh.geometry`` with ``fv(x)`` as it moves.

    Args:
        mesh: object whose ``geometry`` attribute is reassigned.
        fv: callable mapping the slider value ``x`` to a new geometry.

    The spec ``(0, 1, 0.01)`` is presumably read ipywidgets-style as
    min 0, max 1, step 0.01 -- TODO confirm.
    """
    def update(x):
        # The parameter must stay named `x`: the `x=` spec below binds to it.
        mesh.geometry = fv(x)

    # NOTE(review): `interact` is not defined or imported in the visible
    # source. Unless one of the star imports provides it, this line raises
    # NameError -- confirm and import it explicitly (presumably ipywidgets).
    interact(update, x=(0, 1, 0.01))

def draw(*vecs):
    """Display an interactive 3D scene with coordinate axes and the given objects.

    Args:
        *vecs: objects exposing a ``.mesh`` attribute (e.g. ``arrow`` or
            ``mark`` instances); nested lists/tuples of them are flattened.

    Three gray arrows of length 5 mark the x, y and z axes. The view is a
    500x500 black-background renderer with orbit controls around a
    perspective camera.
    """
    axis_meshes = [
        arrow(Vector(*tip), color='gray', w=0.02).mesh
        for tip in ((5, 0, 0), (0, 5, 0), (0, 0, 5))
    ]
    object_meshes = [obj.mesh for obj in flatten(vecs)]

    camera = PerspectiveCamera(position=[10, 15, 25], fov=25)
    # Only an ambient light is added (no directional light).
    lights = [AmbientLight(color='#ffffff')]
    scene = Scene(children=axis_meshes + object_meshes + [camera] + lights)

    renderer = Renderer(
        camera=camera,
        background='black',
        background_opacity=1,
        antialias=True,
        scene=scene,
        controls=[OrbitControls(controlling=camera)],
        width=500,
        height=500,
    )
    # NOTE(review): `display` is not imported here; presumably this relies on
    # IPython providing it as a builtin inside notebooks -- confirm.
    display(renderer)