1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// http://hyperphysics.phy-astr.gsu.edu/hbase/oscda.html
// https://www.ryanjuckett.com/damped-springs/

/// A damped harmonic oscillator evaluated in closed form.
///
/// The spring starts at `initial_position` with `initial_velocity` at `t = 0`
/// and is pulled towards a fixed rest position of `1.0`.
#[derive(Debug, Clone, Copy)]
pub struct Spring {
    /// Mass attached to the spring (`m`).
    pub mass: f32,
    /// Spring constant (`k`).
    pub stiffness: f32,
    /// Viscous damping coefficient (`c`).
    pub damping: f32,
    /// Velocity at `t = 0`.
    pub(crate) initial_velocity: f32,
    /// Position at `t = 0`.
    pub(crate) initial_position: f32,
    /// Initialised to `0.0` by every constructor; none of the methods in this
    /// `impl` read or write it afterwards.
    pub(crate) last_update: f32,
}

/// Maximum distance from the target, and maximum speed, at which
/// [`Spring::done`] considers the spring settled.
const TOLERANCE: f32 = 0.00025;

impl Spring {
    /// Rest position every spring converges to.
    const TARGET: f32 = 1.0;

    /// Half-width of the band around `zeta == 1` that is treated as critically
    /// damped. An exact float comparison would almost never match, and the
    /// under/overdamped formulas divide by a quantity that vanishes there.
    const CRITICAL_EPSILON: f32 = 1e-4;

    /// Creates a spring that starts at rest at position `0.0`.
    pub fn new(mass: f32, stiffness: f32, damping: f32) -> Self {
        Self::new_with_velocity(mass, stiffness, damping, 0.0)
    }

    /// Creates a spring that starts at position `0.0` with the given velocity.
    pub fn new_with_velocity(
        mass: f32,
        stiffness: f32,
        damping: f32,
        initial_velocity: f32,
    ) -> Self {
        Spring {
            mass,
            stiffness,
            damping,
            initial_velocity,
            initial_position: 0.0,
            last_update: 0.0,
        }
    }

    /// Creates a unit-mass spring from a duration and a bounce amount.
    ///
    /// `duration` is the period of the undamped oscillation (the natural
    /// frequency is `2π / duration`). The resulting damping ratio is
    /// `1 - bounce`, so:
    ///
    /// * `bounce == 0` gives a critically damped spring,
    /// * `0 < bounce < 1` gives an underdamped (bouncy) spring,
    /// * `bounce < 0` gives an overdamped spring.
    ///
    /// Values of `bounce >= 1` produce zero or negative damping, and such a
    /// spring never settles. A `duration` of `0.0` yields non-finite
    /// parameters. Neither input is validated here.
    pub fn with_duration_and_bounce(duration: f32, bounce: f32) -> Self {
        let mass = 1.0;
        let omega = 2.0 * std::f32::consts::PI / duration;
        let stiffness = mass * omega.powi(2);
        // Critical damping is `2 * m * omega`; scaling it by `1 - bounce`
        // covers all three regimes with one expression.
        let damping = 2.0 * mass * omega * (1.0 - bounce);

        Self::new(mass, stiffness, damping)
    }

    /// Same as [`Spring::with_duration_and_bounce`], with a starting velocity.
    pub fn with_duration_bounce_and_velocity(
        duration: f32,
        bounce: f32,
        initial_velocity: f32,
    ) -> Self {
        let mut spring = Spring::with_duration_and_bounce(duration, bounce);
        spring.initial_velocity = initial_velocity;
        spring
    }

    /// Returns `(position, velocity)` of the spring `t` time units after start.
    ///
    /// The result is the analytic solution of `m·x'' + c·x' + k·(x - 1) = 0`
    /// with `x(0) = initial_position` and `x'(0) = initial_velocity`; the
    /// returned velocity is the exact time derivative of the returned position.
    pub fn update_pos_vel_at(&self, t: f32) -> (f32, f32) {
        let target = Self::TARGET;
        let omega = (self.stiffness / self.mass).sqrt();
        let zeta = self.damping / (2.0 * (self.mass * self.stiffness).sqrt());
        let x0 = self.initial_position - target;
        let v0 = self.initial_velocity;

        if (zeta - 1.0).abs() <= Self::CRITICAL_EPSILON {
            // Critically damped: x = target + e^{-ωt}·(x0 + (v0 + ω·x0)·t)
            let decay = (-omega * t).exp();
            let slope = v0 + omega * x0;
            let position = target + decay * (x0 + slope * t);
            let velocity = decay * (v0 - omega * slope * t);
            (position, velocity)
        } else if zeta < 1.0 {
            // Underdamped: x = target + e^{-ζωt}·(x0·cos(ω_d·t) + B·sin(ω_d·t))
            let omega_d = omega * (1.0 - zeta * zeta).sqrt();
            let decay = (-zeta * omega * t).exp();
            let (sin_term, cos_term) = (omega_d * t).sin_cos();
            let sin_coeff = (v0 + zeta * omega * x0) / omega_d;
            let position = target + decay * (x0 * cos_term + sin_coeff * sin_term);
            // Differentiating the position: the cosine coefficient collapses to
            // v0 and the sine coefficient to (ζ·ω·v0 + ω²·x0) / ω_d.
            let velocity = decay
                * (v0 * cos_term
                    - (zeta * omega * v0 + omega * omega * x0) / omega_d * sin_term);
            (position, velocity)
        } else {
            // Overdamped: x = target + c1·e^{r1·t} + c2·e^{r2·t}, where
            // c1 + c2 = x0 and c1·r1 + c2·r2 = v0 fix the two coefficients.
            let root = (zeta * zeta - 1.0).sqrt();
            let r1 = -omega * (zeta + root);
            let r2 = -omega * (zeta - root);
            let c1 = (v0 - r2 * x0) / (r1 - r2);
            let c2 = x0 - c1;
            let term1 = c1 * (r1 * t).exp();
            let term2 = c2 * (r2 * t).exp();
            (target + term1 + term2, r1 * term1 + r2 * term2)
        }
    }

    /// Returns the spring position at `elapsed`.
    ///
    /// No state is modified; `&mut self` is kept only so existing callers keep
    /// compiling unchanged.
    pub fn update_at(&mut self, elapsed: f32) -> f32 {
        self.update_pos_vel_at(elapsed).0
    }

    /// Returns `true` once both the distance to the target and the speed at
    /// `elapsed` are below `TOLERANCE`.
    pub fn done(&self, elapsed: f32) -> bool {
        let (position, velocity) = self.update_pos_vel_at(elapsed);

        (position - Self::TARGET).abs() < TOLERANCE && velocity.abs() < TOLERANCE
    }
}