279 lines
8.6 KiB
Python
279 lines
8.6 KiB
Python
from vec3 import Vec3
|
|
from typing import Union, List
|
|
from matplotlib import pyplot as plt
|
|
import time
|
|
import numpy as np
|
|
from vispy import scene, app
|
|
|
|
|
|
"""
|
|
Galaxy simulation using the Barnes-Hut algorithm.
|
|
Based on the guide at http://arborjs.org/docs/barnes-hut
|
|
"""
|
|
|
|
"""Global const"""
|
|
# Opening threshold of the Barnes-Hut approximation: a larger value is faster
# but less precise. Node.get_force compares squared quantities, hence the
# pre-computed square.
THETA = 0.5
THETA_SQUARED = THETA * THETA

# Index of each child cell inside Node.children, by octant name.
CHILDREN_CONV = {
    "top_left_front": 0,
    "top_left_back": 1,
    "top_right_front": 2,
    "top_right_back": 3,
    "bottom_left_front": 4,
    "bottom_left_back": 5,
    "bottom_right_front": 6,
    "bottom_right_back": 7,
}

# Same indices keyed by (is_top, is_left, is_front), where
# is_top, is_left, is_front = pos.y > center.y, pos.x < center.x, pos.z < center.z
CHILDREN_CONV_BOOL = {
    (True, True, True): 0,
    (True, True, False): 1,
    (True, False, True): 2,
    (True, False, False): 3,
    (False, True, True): 4,
    (False, True, False): 5,
    (False, False, True): 6,
    (False, False, False): 7,
}

# Softening length added to the distance in newton_law.
SMOOTHING = 0.5

# Gravitational constant used by newton_law (simulation units).
G = 10
|
|
|
|
"""useful functions"""
|
|
|
|
def newton_law(mass1: float, mass2: float, vec_diff: Vec3) -> Vec3:
    """Return the gravitational force that mass1 exerts on mass2.

    ``vec_diff`` is the separation vector between the two bodies; the result
    is parallel to it. The distance is softened by adding SMOOTHING to the
    norm of ``vec_diff`` before cubing, so the force stays finite when the
    two bodies are very close.
    """
    softened = vec_diff.norm + SMOOTHING
    magnitude = mass1 * mass2 / (softened * softened * softened)
    return G * vec_diff * magnitude
|
|
|
|
def get_time(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged. The
    wrapper keeps the name, docstring and other metadata of ``func``.
    """
    # Local import so this helper does not depend on the file's import block.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start} seconds")
        return result

    return wrapper
|
|
|
|
|
|
"""Type defintion"""
|
|
|
|
class Star:
    """A point mass of the simulation: position, velocity and mass."""

    __slots__ = ("pos", "speed", "mass")

    def __init__(self, pos: Vec3, speed: Vec3, mass: float):
        self.pos, self.speed, self.mass = pos, speed, mass
|
|
|
|
|
|
class Node:
    """Cell of the Barnes-Hut octree.

    A node is *external* (a leaf) while ``children`` is empty and *internal*
    once it has been split into eight children. ``mass`` and
    ``center_of_mass`` aggregate every body inserted below the node.

    ``width`` is the half-extent of the cubic cell: ``create_children`` puts
    the child centers at ``center +/- width / 2`` and gives them a width of
    ``width / 2``.
    """

    __slots__ = "center", "center_of_mass", "mass", "children", "width"

    # Cells narrower than this are never subdivided; bodies landing in such a
    # cell are merged into one aggregate so insertion always terminates.
    _MIN_WIDTH = 1e-9

    def __init__(self, center: Vec3, center_of_mass: Vec3, mass: float, children, width: float):
        self.center = center
        self.center_of_mass = center_of_mass
        self.mass = mass
        # [] if external, else 8 nodes: child 0 is top_left_front,
        # 1 is top_left_back, etc. See CHILDREN_CONV.
        self.children = children
        self.width = width

    @property
    def is_external(self) -> bool:
        """Return True if the node is external (has no children)."""
        return not self.children

    def is_empty(self) -> bool:
        """Return True if the node is empty (its aggregated mass is 0)."""
        return self.mass == 0

    def __repr__(self):
        kind = "External" if self.is_external else "Internal"
        return (
            f"{kind}(center={self.center}, center_of_mass={self.center_of_mass}, "
            f"mass={self.mass}, width={self.width})"
        )

    def get_child_index(self, pos: Vec3) -> int:
        """Return the index of the child that ``pos`` falls into. See CHILDREN_CONV."""
        center = self.center
        is_top, is_left, is_front = pos.y > center.y, pos.x < center.x, pos.z < center.z
        return CHILDREN_CONV_BOOL[(is_top, is_left, is_front)]

    def create_children(self) -> None:
        """Split the node into 8 empty children, each with half the width."""
        half = self.width / 2
        center = self.center
        children = [None] * 8
        for is_top in (True, False):
            for is_left in (True, False):
                for is_front in (True, False):
                    child_center = Vec3(
                        center.x - half if is_left else center.x + half,
                        center.y + half if is_top else center.y - half,
                        center.z - half if is_front else center.z + half,
                    )
                    index = CHILDREN_CONV_BOOL[(is_top, is_left, is_front)]
                    children[index] = Node(child_center, Vec3(0, 0, 0), 0, [], half)
        self.children = children

    def insert(self, star: Star) -> None:
        """Insert a star in the tree."""
        self._insert_body(star.pos, star.mass)

    def _insert_body(self, pos: Vec3, mass: float) -> None:
        """Insert a point mass (position + mass) into this subtree."""
        if self.is_empty():
            # NOTE(review): the cell keeps the caller's position object, not a
            # copy; whether later in-place updates of star.pos show through
            # depends on Vec3 - confirm.
            self.center_of_mass = pos
            self.mass = mass
            return

        if self.is_external:
            resident_pos, resident_mass = self.center_of_mass, self.mass
            if (resident_pos.x == pos.x and resident_pos.y == pos.y
                    and resident_pos.z == pos.z):
                # Coincident bodies can never be separated by subdividing.
                self.mass += mass
                return
            if self.width < self._MIN_WIDTH:
                # Too small to split further: keep one aggregate body.
                total = resident_mass + mass
                self.center_of_mass = (resident_pos * resident_mass + pos * mass) / total
                self.mass = total
                return
            # The leaf already holds a body: split it and push that body down
            # first, otherwise it would vanish from the subtree.
            self.create_children()
            self.children[self.get_child_index(resident_pos)]._insert_body(
                resident_pos, resident_mass
            )

        total = self.mass + mass
        self.center_of_mass = (self.center_of_mass * self.mass + pos * mass) / total
        self.mass = total
        self.children[self.get_child_index(pos)]._insert_body(pos, mass)

    def get_force(self, star: Star) -> Vec3:
        """Return the force applied by the node on the star.

        Performance bottleneck (called for each star). A leaf, or an internal
        node whose width is small compared with its distance to the star
        (``width**2 < THETA_SQUARED * dist**2``), acts as a single point mass
        located at its center of mass; otherwise the children are summed.
        """
        if self.mass == 0:
            return Vec3.zero()

        offset = self.center_of_mass - star.pos
        if self.is_external:
            return newton_law(self.mass, star.mass, offset)

        # Multiplied form of width**2 / dist**2 < THETA**2: no division, so a
        # star sitting exactly on the center of mass just opens the node.
        if self.width * self.width < THETA_SQUARED * offset.norm_squared:
            return newton_law(self.mass, star.mass, offset)

        acc = Vec3.zero()
        for child in self.children:
            acc += child.get_force(star)
        return acc
|
|
|
|
#return sum((child.get_force(star) for child in self.children), Vec3.zero()) #slower
|
|
|
|
|
|
|
|
"""Init simulation"""
|
|
# Length scale of the simulation: create_stars draws positions with a standard
# deviation of WIDTH / 5 and create_tree uses it to size the root cell.
WIDTH = 50
|
|
|
|
"""Create random stars"""
|
|
def create_stars(nb_stars: int) -> List[Star]:
    """Return a galaxy-like list of ``nb_stars`` stars plus a central black hole.

    Positions are normally distributed around the origin and flattened along
    y, so the stars lie more or less in the y = 0 plane. Each star gets a
    velocity in the xz plane, perpendicular to its position, whose magnitude
    equals its distance to the origin, so the stars rotate around the center.
    The last element of the list is the black hole: at rest at the origin,
    with a mass of ``nb_stars / 10``.
    """
    def rand():
        return np.random.normal() * WIDTH / 5

    stars = []
    for _ in range(nb_stars):
        x, y, z = rand(), rand() / 5, rand()
        pos = Vec3(x, y, z)
        if x == 0 and z == 0:
            # On the rotation axis there is no tangential direction.
            speed = Vec3(0, 0, 0)
        else:
            # (z, 0, -x) is perpendicular to (x, 0, z) and needs no division,
            # so z == 0 is handled like any other value.
            speed = Vec3(z, 0, -x).normalize() * pos.norm
        stars.append(Star(pos, speed, 1))

    # add center blackhole
    stars.append(Star(Vec3(0, 0, 0), Vec3(0, 0, 0), nb_stars / 10))

    return stars
|
|
|
|
def create_tree(stars: List[Star]) -> Node:
    """Build the octree holding ``stars`` and return its root.

    The root cell is centered on the bounding box of the stars and is large
    enough to contain all of them, with a half-extent of at least WIDTH.
    Node treats its ``width`` as a half-extent (children are offset by
    ``width / 2``), so the root spans ``center +/- half_width`` on each axis.
    An empty list yields an empty root centered on the origin.
    """
    if not stars:
        return Node(Vec3(0, 0, 0), Vec3(0, 0, 0), 0, [], WIDTH)

    xs = [star.pos.x for star in stars]
    ys = [star.pos.y for star in stars]
    zs = [star.pos.z for star in stars]
    lo_x, hi_x = min(xs), max(xs)
    lo_y, hi_y = min(ys), max(ys)
    lo_z, hi_z = min(zs), max(zs)

    center = Vec3((lo_x + hi_x) / 2, (lo_y + hi_y) / 2, (lo_z + hi_z) / 2)
    half_width = max(WIDTH, (hi_x - lo_x) / 2, (hi_y - lo_y) / 2, (hi_z - lo_z) / 2)

    root = Node(center, Vec3(0, 0, 0), 0, [], half_width)
    for star in stars:
        root.insert(star)
    return root
|
|
|
|
|
|
def update_star(star, root, dt):
    """Update the speed and the position of a star using the leapfrog method.

    Half velocity kick, full position drift, then a second half kick. Both
    kicks read the force from ``root``, i.e. from the tree as it was built
    before the drift.

    The kick uses the acceleration ``force / star.mass``, so stars whose mass
    is not 1 are integrated correctly. A massless star feels no force and
    only drifts.
    """
    # dt / 2 divided by the mass: turns the force into a velocity increment.
    half_kick = dt / (2 * star.mass) if star.mass else 0.0

    star.speed += root.get_force(star) * half_kick
    star.pos += star.speed * dt
    star.speed += root.get_force(star) * half_kick
|
|
|
|
|
|
def update_speed_and_pos(stars: List[Star], root: Node, dt: float) -> None:
    """Advance every star by one time step ``dt``, reading forces from ``root``.

    Stars are updated in place, one after the other, in list order.
    """
    for body in stars:
        update_star(body, root, dt)
|
|
|
|
|
|
|
|
def compute_next_state(stars: List[Star], dt: float):
    """Advance the simulation by one step of length ``dt``.

    A fresh tree is built from the current star positions, then every star
    is updated in place against that tree.
    """
    update_speed_and_pos(stars, create_tree(stars), dt)
|
|
|
|
"""
|
|
Graphical part
|
|
"""
|
|
# Module-level side effect: selects matplotlib's "dark_background" style as
# soon as this file is imported or run.
plt.style.use("dark_background")
|
|
def _reset_star_axes(ax, size):
    """Hide axis and grid and fix the limits of ``ax`` to +/- ``size``."""
    ax.set_axis_off()
    ax.grid(False)
    ax.set_xlim(-size, size)
    ax.set_ylim(-size, size)
    ax.set_zlim(-size, size)
    ax.set_aspect('equal')


def animate_stars_plt(stars_array, dt):
    """Animate the stars with matplotlib until the figure window is closed.

    Each frame advances the simulation by ``dt`` (``stars_array`` is updated
    in place), redraws the stars on axes with the axis and grid turned off
    and fixed limits, and prints the achieved frame rate.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    size = 20
    _reset_star_axes(ax, size)
    # camera: elevation 90 degrees, azimuth 90 degrees
    ax.view_init(elev=90, azim=90)

    # Stop when the user closes the window instead of looping forever.
    while plt.fignum_exists(fig.number):
        start = time.perf_counter()

        # ax.clear() drops the axis settings, so they are re-applied each frame.
        ax.clear()
        _reset_star_axes(ax, size)

        compute_next_state(stars_array, dt)
        if stars_array:
            X, Y, Z = zip(*[star.pos for star in stars_array])
            ax.scatter(X, Y, Z, s=0.1, color='r')

        end = time.perf_counter()
        print(f"fps: {1/(end-start):}")
        plt.pause(0.01)
|
|
|
|
|
|
def animate_stars_vispy(stars_array, dt):
    """Animate the stars in the array. Faster than the matplotlib version.

    Opens a vispy canvas with a turntable camera and advances the simulation
    by ``dt`` on every timer tick (``stars_array`` is updated in place). The
    frame rate is capped at MAX_FPS by the timer interval. Blocks until the
    vispy application exits.
    """
    MAX_FPS = 20

    canvas = scene.SceneCanvas(keys='interactive', show=True)
    view = canvas.central_widget.add_view()
    # center the view
    view.camera = 'turntable'
    view.camera.distance = 50

    # Created before the timer so the callback never runs without its visual.
    scatter = scene.visuals.Markers()
    view.add(scatter)

    @get_time
    def update(event):
        compute_next_state(stars_array, dt)
        points = np.array([star.pos.np for star in stars_array])
        scatter.set_data(points, edge_color=None, face_color=(1, 0.5, 1, 0.3), size=4)
        canvas.update()

    # The timer interval enforces the frame-rate cap; sleeping inside the
    # callback would block the GUI event loop. The reference must be kept
    # alive while the application runs.
    timer = app.Timer(interval=1 / MAX_FPS, connect=update, start=True)

    canvas.show()
    app.run()
|
|
|
|
|
|
|
|
def main():
    """Generate a random galaxy and run the vispy animation on it."""
    star_count = 10**3
    time_step = 0.01
    animate_stars_vispy(create_stars(star_count), time_step)


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|