Galaxy-sim/barnes-hut-galaxy-new.py
Crizomb 052870b7de
increasing performances
in newton_law d*d*d is faster than d**3
2023-02-15 13:40:18 +01:00

279 lines
8.6 KiB
Python

from vec3 import Vec3
from typing import Union, List
from matplotlib import pyplot as plt
import time
import numpy as np
from vispy import scene, app
"""
Galaxy simulation using the Barnes-Hut algorithm.
Follows this guide: http://arborjs.org/docs/barnes-hut
"""
# Global constants
THETA = 0.5  # more is less precise but faster
THETA_SQUARED = THETA ** 2

# Octant name -> index of that octant in Node.children.
CHILDREN_CONV = {
    "top_left_front": 0,
    "top_left_back": 1,
    "top_right_front": 2,
    "top_right_back": 3,
    "bottom_left_front": 4,
    "bottom_left_back": 5,
    "bottom_right_front": 6,
    "bottom_right_back": 7,
}

# (is_top, is_left, is_front) -> octant index, with
# is_top, is_left, is_front = pos.y > center.y, pos.x < center.x, pos.z < center.z
# Every False flag sets one bit of the index: bottom adds 4, right adds 2, back adds 1.
CHILDREN_CONV_BOOL = {
    (is_top, is_left, is_front): 4 * (not is_top) + 2 * (not is_left) + (not is_front)
    for is_top in (True, False)
    for is_left in (True, False)
    for is_front in (True, False)
}

SMOOTHING = 0.5  # added to the distance in newton_law
G = 10  # multiplicative constant of newton_law
"""useful functions"""
def newton_law(mass1: float, mass2: float, vec_diff: Vec3) -> Vec3:
    """Return the softened gravitational force of mass1 on mass2.

    The result is ``G * mass1 * mass2 / (|vec_diff| + SMOOTHING)**3 * vec_diff``,
    so it points along ``vec_diff``. ``SMOOTHING`` is added to the distance
    before cubing, which keeps the denominator away from zero when the two
    bodies are very close.
    """
    softened = vec_diff.norm + SMOOTHING
    # Three multiplications instead of ``softened ** 3`` (hot path).
    scale = mass1 * mass2 / (softened * softened * softened)
    return G * vec_diff * scale
def get_time(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged and its
    metadata (``__name__``, ``__doc__``, ...) is preserved on the wrapper.
    The timing line is printed only when the call returns normally; if
    ``func`` raises, the exception propagates and nothing is printed.
    """
    # Local import so this block does not depend on the file's import header.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start} seconds")
        return result

    return wrapper
"""Type definitions"""
class Star:
    """A point mass of the simulation: position, velocity and mass.

    ``__slots__`` is declared, so instances carry no ``__dict__`` and only
    the three listed attributes can be set.
    """

    __slots__ = ("pos", "speed", "mass")

    def __init__(self, pos: Vec3, speed: Vec3, mass: float):
        # The arguments are stored as given (no copy is made).
        self.pos, self.speed, self.mass = pos, speed, mass
class Node:
    """Octree cell of the Barnes-Hut tree.

    A node is *external* (a leaf) while ``children`` is empty and *internal*
    once it has been split into eight octants. Each node stores the total
    mass and the centre of mass of every body inserted at or below it.

    Children are centred at ``center +/- width / 2`` on every axis and get
    ``width / 2`` as their own width, so ``width`` behaves as the half side
    length of the cell.
    """

    __slots__ = "center", "center_of_mass", "mass", "children", "width"

    # Leaves narrower than this are not split any further; additional bodies
    # are merged into them instead. Without this bound two bodies at the same
    # position would subdivide forever.
    _MIN_SPLIT_WIDTH = 1e-6

    def __init__(self, center: Vec3, center_of_mass: Vec3, mass: float, children, width: float):
        self.center = center
        self.center_of_mass = center_of_mass
        self.mass = mass
        # [] if external, else 8 nodes: child 0 is top_left_front,
        # 1 is top_left_back, etc... see CHILDREN_CONV
        self.children = children
        self.width = width

    @property
    def is_external(self) -> bool:
        """Return True if the node is external (has no children)."""
        return not self.children

    def is_empty(self) -> bool:
        """Return True if the node is empty (holds no mass)."""
        return self.mass == 0

    def __repr__(self):
        kind = "External" if self.is_external else "Internal"
        return (
            f"{kind}(center={self.center}, center_of_mass={self.center_of_mass}, "
            f"mass={self.mass}, width={self.width})"
        )

    def get_child_index(self, pos: Vec3) -> int:
        """Return the index of the child octant that ``pos`` falls in. See CHILDREN_CONV."""
        center = self.center
        is_top, is_left, is_front = pos.y > center.y, pos.x < center.x, pos.z < center.z
        return CHILDREN_CONV_BOOL[(is_top, is_left, is_front)]

    def create_children(self) -> None:
        """Split the node into 8 empty children, each with half of this node's width."""
        children = [None] * 8
        new_width = self.width / 2
        center = self.center
        for (is_top, is_left, is_front), index in CHILDREN_CONV_BOOL.items():
            center_x = center.x - new_width if is_left else center.x + new_width
            center_y = center.y + new_width if is_top else center.y - new_width
            center_z = center.z - new_width if is_front else center.z + new_width
            children[index] = Node(Vec3(center_x, center_y, center_z), Vec3(0, 0, 0), 0, [], new_width)
        self.children = children

    def insert(self, star: Star) -> None:
        """Insert a star in the tree rooted at this node."""
        self._insert_body(star.pos, star.mass)

    def _insert_body(self, pos: Vec3, mass: float) -> None:
        """Insert a point mass, splitting occupied leaves as needed."""
        if self.is_empty():
            # First body in this cell: the leaf simply represents it.
            self.center_of_mass = pos
            self.mass = mass
            return

        if self.is_external:
            if self.width < self._MIN_SPLIT_WIDTH:
                # Cell too small to split (e.g. coincident bodies): keep it a
                # leaf that stands for the merged bodies.
                self.center_of_mass = (self.center_of_mass * self.mass + pos * mass) / (self.mass + mass)
                self.mass += mass
                return
            # Occupied leaf: split it and push the body it already holds down
            # into its octant, otherwise that body would be missing from the
            # subtree whenever get_force descends into the children.
            resident_pos, resident_mass = self.center_of_mass, self.mass
            self.create_children()
            self.children[self.get_child_index(resident_pos)]._insert_body(resident_pos, resident_mass)

        # Internal node: update the aggregate, then forward the new body.
        self.center_of_mass = (self.center_of_mass * self.mass + pos * mass) / (self.mass + mass)
        self.mass += mass
        self.children[self.get_child_index(pos)]._insert_body(pos, mass)

    def get_force(self, star: Star) -> Vec3:
        """Return the force applied by the node on the star.

        Performance bottleneck (called for each star). A leaf, or an internal
        node satisfying ``width**2 < THETA_SQUARED * distance**2``, is treated
        as a single point mass located at its centre of mass; otherwise the
        forces of the children are summed.
        """
        offset = self.center_of_mass - star.pos
        # Written as a product rather than width^2 / dist^2 so that a star
        # sitting exactly on the centre of mass cannot divide by zero.
        if self.is_external or self.width * self.width < THETA_SQUARED * offset.norm_squared:
            # When the centre of mass coincides with the star's own position
            # the offset is the zero vector and so is the returned force.
            return newton_law(self.mass, star.mass, offset)
        acc = Vec3.zero()
        for child in self.children:
            acc += child.get_force(star)
        return acc
"""Init simulation"""
# Length scale of the simulation: create_stars draws coordinates from a normal
# distribution scaled by WIDTH / 5, and create_tree passes WIDTH as the root
# node's width and WIDTH / 2 as each coordinate of the root node's center.
WIDTH = 50
"""Create random stars"""
def create_stars(nb_stars: int) -> List[Star]:
    """Return a galaxy-like list of stars rotating around the origin.

    ``nb_stars`` stars of mass 1 are drawn from a normal distribution scaled
    by ``WIDTH / 5`` on x and z and by a further factor 1/5 on y, so the
    stars lie more or less in the y=0 plane. Each star starts with a speed
    tangent to its circle around the y axis, of magnitude equal to its
    distance from the origin.

    A central body of mass ``nb_stars / 10`` at rest at the origin is appended
    last, so the returned list holds ``nb_stars + 1`` elements.
    """
    def rand() -> float:
        return np.random.normal() * WIDTH / 5

    stars = []
    for _ in range(nb_stars):
        x, y, z = rand(), rand() / 5, rand()
        pos = Vec3(x, y, z)
        if x == 0 and z == 0:
            # On the rotation axis there is no tangential direction.
            speed = Vec3(0, 0, 0)
        else:
            # (z, 0, -x) is orthogonal to (x, 0, z). It has the same direction
            # as (1, 0, -x/z) for z > 0 and as (-1, 0, x/z) for z < 0, but
            # needs no division by z, so z == 0 is not a special case.
            tangent = Vec3(z, 0, -x)
            speed = tangent.normalize() * pos.norm
        stars.append(Star(pos, speed, 1))
    # add center blackhole
    stars.append(Star(Vec3(0, 0, 0), Vec3(0, 0, 0), nb_stars / 10))
    return stars
def create_tree(stars: List[Star]) -> Node:
    """Build a fresh tree holding every star and return its root.

    The root node is centred at (WIDTH / 2, WIDTH / 2, WIDTH / 2) with width
    WIDTH; the stars are inserted one by one in list order.
    """
    half = WIDTH / 2
    tree = Node(Vec3(half, half, half), Vec3(0, 0, 0), 0, [], WIDTH)
    for body in stars:
        tree.insert(body)
    return tree
def update_star(star, root, dt):
    """Advance one star by ``dt`` in place (leapfrog: kick, drift, kick).

    Both force evaluations query the same ``root``; the tree is not rebuilt
    between the two half kicks.
    """
    # First half kick: speed advances by half a step.
    star.speed += root.get_force(star) * dt / 2
    # Drift: position advances a full step with the updated speed.
    star.pos += star.speed * dt
    # Second half kick, with the force evaluated at the new position.
    star.speed += root.get_force(star) * dt / 2
def update_speed_and_pos(stars: List[Star], root: Node, dt: float) -> None:
    """Advance every star by ``dt`` in place, in list order.

    All stars are updated against the same ``root``.
    """
    for body in stars:
        update_star(body, root, dt)
def compute_next_state(stars: List[Star], dt: float):
    """Advance the simulation by one step of length ``dt``.

    A new tree is built from the current star positions, then every star is
    updated in place against that tree.
    """
    update_speed_and_pos(stars, create_tree(stars), dt)
"""
Graphical part
"""
# Module-level statement: runs whenever this file is executed or imported and
# selects matplotlib's "dark_background" style for figures created afterwards.
plt.style.use("dark_background")
def _setup_star_axes(ax, size):
    """Hide the axes and grid and pin all three limits to [-size, size]."""
    ax.set_axis_off()
    ax.grid(False)
    ax.set_xlim(-size, size)
    ax.set_ylim(-size, size)
    ax.set_zlim(-size, size)
    ax.set_aspect('equal')


def animate_stars_plt(stars_array, dt):
    """Animate the stars with matplotlib until the figure window is closed.

    Axes and grid are turned off and the plotted region is fixed. Each frame
    advances the simulation by ``dt`` (the stars in ``stars_array`` are
    mutated in place), redraws them as a 3D scatter plot and prints the
    achieved frame rate.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    size = 20
    _setup_star_axes(ax, size)
    # camera orientation (original note: view parallel to the xz plane)
    ax.view_init(elev=90, azim=90)
    # Stop once the window is closed instead of simulating forever.
    while plt.fignum_exists(fig.number):
        start = time.perf_counter()
        # clear() resets limits and axis visibility, so they are re-applied
        # on every frame.
        ax.clear()
        _setup_star_axes(ax, size)
        compute_next_state(stars_array, dt)
        X, Y, Z = zip(*[star.pos for star in stars_array])
        ax.scatter(X, Y, Z, s=0.1, color='r')
        end = time.perf_counter()
        print(f"fps: {1/(end-start):}")
        # Lets the GUI event loop run and redraw the figure.
        plt.pause(0.01)
def animate_stars_vispy(stars_array, dt):
    """Animate the stars in the array. Faster than the matplotlib version.

    Opens an interactive vispy canvas with a turntable camera and advances
    the simulation by ``dt`` on every timer tick (the stars in
    ``stars_array`` are mutated in place). Blocks until the vispy event loop
    exits.
    """
    MAX_FPS = 20
    canvas = scene.SceneCanvas(keys='interactive', show=True)
    view = canvas.central_widget.add_view()
    # center the view
    view.camera = 'turntable'
    view.camera.distance = 50

    # Create the visual before the timer whose callback updates it.
    scatter = scene.visuals.Markers()
    view.add(scatter)

    @get_time
    def update(event):
        compute_next_state(stars_array, dt)
        points = np.array([star.pos.np for star in stars_array])
        scatter.set_data(points, edge_color=None, face_color=(1, 0.5, 1, 0.3), size=4)
        canvas.update()

    # The timer interval caps the frame rate; sleeping inside the callback
    # would block the GUI event loop. The local name keeps the timer
    # referenced while the event loop is running.
    timer = app.Timer(interval=1 / MAX_FPS, connect=update, start=True)
    canvas.show()
    app.run()
def main():
    """Create a 1000-star galaxy and animate it with vispy using dt = 0.01."""
    galaxy = create_stars(1000)
    animate_stars_vispy(galaxy, 0.01)


# Run the simulation only when the file is executed as a script.
if __name__ == '__main__':
    main()