import numpy as np
import quantities as pq
from elephant.gpfa import GPFA
from elephant.spike_train_generation import StationaryPoissonProcess
from viziphant.gpfa import plot_trajectories
from matplotlib import pyplot as plt

# Simulate `n_trials` trials, each made of `n_channels` independent
# stationary Poisson spike trains. Every trial draws fresh integer firing
# rates in [1, 100) Hz, so a given channel's rate differs between trials.
# Firing rates come from NumPy's global RNG with no fixed seed, so the data
# (and the resulting figure) change from run to run.
n_trials = 50
n_channels = 20  # loop-invariant: define once instead of on every trial

data = []  # one list of spike trains per trial
for _ in range(n_trials):
    firing_rates = np.random.randint(low=1, high=100,
                                     size=n_channels) * pq.Hz
    # NOTE(review): trial duration comes from StationaryPoissonProcess's
    # default t_start/t_stop — confirm against the elephant documentation.
    spike_times = [StationaryPoissonProcess(rate=rate).generate_spiketrain()
                   for rate in firing_rates]
    data.append(spike_times)

# Fit an 8-dimensional GPFA model to all simulated trials using 20 ms bins,
# then project the same trials into the latent space. Both the orthonormalized
# and the raw latent trajectories are requested, since both are passed on to
# the plotting call.
gpfa = GPFA(x_dim=8, bin_size=20 * pq.ms)
gpfa.fit(data)

results = gpfa.transform(
    data,
    returned_data=['latent_variable_orth', 'latent_variable'],
)

# Split the trial indices 0..len(data)-1 into one contiguous chunk per group
# name. With 50 trials and 5 names this yields 'A' -> trials 0-9, ...,
# 'E' -> trials 40-49. Deriving the indices from len(data) keeps the grouping
# in sync with the number of simulated trials instead of repeating the
# literal 50 (and the 5 x 10 shape) here.
trial_group_names = ['A', 'B', 'C', 'D', 'E']
# array_split (unlike reshape) also copes with a trial count that is not a
# multiple of the number of groups; earlier chunks get one extra trial.
trial_id_lists = np.array_split(np.arange(len(data)), len(trial_group_names))

# Map each group name to the array of trial indices belonging to it.
trial_grouping_dict = dict(zip(trial_group_names, trial_id_lists))

# Plot the latent trajectories along latent dimensions 0, 1 and 2. Trials are
# grouped according to trial_grouping_dict, and each group's average
# trajectory is drawn as well.
plot_trajectories(
    results,
    gpfa,
    trial_grouping_dict=trial_grouping_dict,
    plot_group_averages=True,
    dimensions=[0, 1, 2],
)
plt.show()  # display the figure (blocking in most interactive backends)