====== SAD (Simplified Action Decoder) ======

{{:...}}

  * [[https://...]]
  * [[https://...]]
  * [[https://...]]
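
The PyTorch script below trains Q-networks for the two-step matrix game used in the SAD/BAD papers: agent 0 acts on its private card, agent 1 acts on an encoding that includes agent 0's action (and, in SAD mode, also agent 0's greedy action), and both agents receive the joint reward given by ''payoff_values''. Two settings selected by ''bad_mode'' are trained, and their mean reward across runs is plotted.
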
<code python>
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from IPython import embed
# payoff values
payoff_values = [
    [[[10, 0, 0],
      [4, 8, 4],
      [10, 0, 0]],
     [[0, 0, 10],
      [4, 8, 4],
      [0, 0, 10]]],
    [[[0, 0, 10],
      [4, 8, 4],
      [0, 0, 0]],
     [[10, 0, 0],
      [4, 8, 4],
      [10, 0, 0]]],
]
payoff_values = np.array(payoff_values)
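# 2 possible private cards per player and 3 actions per player, so the payoff
# tensor is 2x2x3x3; it is indexed by both cards and both actions when the
# joint reward is looked up further down (exact axis order assumed there).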
n_cards = 2

# ...
final_epsilon = 0.05
n_runs = 50
n_episodes = 50000
n_readings = 25

np.random.seed(seed)

# ...
interval = n_episodes // n_readings
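
# bad_mode selects the variant being trained: with bad_mode > 3 (here 4) agent 1
# additionally observes agent 0's greedy action (see `greedy = 1 if bad_mode > 3
# else 0` below), while bad_mode 2 trains the same networks without that input,
# i.e. plain independent Q-learning.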
for bad_mode in [2, 4]:
    print('bad_mode', bad_mode)  # (log message assumed)
    for n_r in range(n_runs):
        print('run', n_r)  # (log message assumed)
        net0 = nn.Sequential(
            nn.Linear(n_cards, 32),      # first layer assumed from input_0 below
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )
        net1 = nn.Sequential(
            nn.Linear(input_size, 32),   # first layer assumed from input_1 below
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )
        optimizer = optim.Adam(
            [
                {'params': net0.parameters()},
                {'params': net1.parameters()},
            ],
            lr=0.001,
        )
        greedy = 1 if bad_mode > 3 else 0
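        # With greedy == 1 (SAD), agent 1 is shown agent 0's greedy action next to
        # the action actually executed, so agent 0's intended policy stays visible
        # to its partner even while agent 0 is exploring.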
        for j in range(n_episodes+1):
            cards_0 = np.random.choice(n_cards, bs)  # (batch argument assumed)
            # ...

            # every `interval`-th episode keeps eps at 0, so that batch is a pure
            # greedy evaluation whose mean reward is recorded into all_r below
            eps = 0
            if j % (interval) != 0:
                eps = max(final_epsilon, 1.0 - j / n_episodes)  # (decay term assumed)

            with torch.no_grad():
                input_0 = np.eye(n_cards)[cards_0]
                input_0 = torch.from_numpy(input_0).float()
                q_vals = net0(input_0)
                qv0, qv0_i = q_vals.max(1)
                # ... (u0: executed action, u0_greedy: greedy action of agent 0)

                # joint_in1 indexes agent 1's one-hot observation; the leading
                # term (not shown) presumably encodes agent 1's own card
                joint_in1 = ... + \
                    u0.numpy() * n_actions + \
                    u0_greedy.numpy() * greedy

                input_1 = np.eye(input_size)[joint_in1]
                input_1 = torch.from_numpy(input_1).float()
                q_vals = net1(input_1)
                qv1, qv1_i = q_vals.max(1)
                # ... (u1, u1_greedy: agent 1's executed and greedy actions)
                q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)
            rew = [
                payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]]  # (index order assumed)
                for i in range(bs)
            ]
            rew = torch.from_numpy(np.array(rew)).float()

            q0 = net0(input_0)[torch.arange(0, bs).long(), u0]
            q1 = net1(input_1)[torch.arange(0, bs).long(), u1]
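            # q0/q1 are the Q-values of the actions actually executed; the loss
            # computed next presumably regresses both toward the shared reward rew,
            # so the two agents receive the same learning signal.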
            optimizer.zero_grad()
            # ...
            loss.backward()
            optimizer.step()
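            # one optimizer updates both agents' networks from the joint loss:
            # training is centralized, while execution (the forward passes above,
            # each conditioned only on that agent's own observation) stays
            # decentralized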

            net0, net1 = net1, net0

            if eps == 0:
                all_r[bad_mode, n_r, j // interval] = rew.mean()  # (index/value assumed)

            if j % (n_episodes // 10) == 0:
                print(j, rew.mean())  # (log format assumed)

        # plot mean reward (+/- 1 standard error over completed runs) for both modes
        colors = ['', '', '#1f77b4', '', '#ff7f0e']  # (hex values assumed)
        plt.figure(figsize=(6, 4))  # (figure height assumed)
        x_vals = np.arange(n_readings+1) * interval
        for bad_mode in [2, 4]:
            vals = all_r[bad_mode][:n_r+1]
            y_m = vals.mean(0)
            y_std = vals.std(0) / (n_runs**0.5)
            plt.plot(x_vals, y_m, colors[bad_mode], label=str(bad_mode))  # (label assumed)
            plt.fill_between(x_vals, y_m - y_std, y_m + y_std, alpha=0.3)  # (band assumed)
        plt.ylim([7.5, 10.5])  # (upper limit assumed)
        plt.legend()
        None
        plt.xlabel('episode')  # (label assumed)
        plt.ylabel('reward')  # (label assumed)
        plt.savefig('sad.png')  # (filename assumed)
        plt.clf()
</code>
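
For reference, here is a minimal, self-contained sketch of the same idea. It is not the original script: the network sizes, exploration schedule, batch size, per-agent MSE loss toward the shared reward, and the ''run''/''sad'' names are all assumptions chosen to keep the example short; it only illustrates how the SAD-style greedy-action input enters agent 1's observation.

<code python>
# Minimal sketch (not the original script): small MLP Q-learners on the same
# matrix game. `sad=True` feeds agent 1 the greedy action of agent 0 alongside
# the executed (possibly exploratory) action.
import torch
import torch.nn as nn
import torch.optim as optim

# payoff[card0, card1, u0, u1], same values as above
payoff = torch.tensor([
    [[[10, 0, 0], [4, 8, 4], [10, 0, 0]],
     [[0, 0, 10], [4, 8, 4], [0, 0, 10]]],
    [[[0, 0, 10], [4, 8, 4], [0, 0, 0]],
     [[10, 0, 0], [4, 8, 4], [10, 0, 0]]],
], dtype=torch.float32)

n_cards, n_actions, bs = 2, 3, 64


def run(sad, episodes=20000, eps_final=0.05, seed=0):
    """Train both agents; `sad=True` adds agent 0's greedy action to agent 1's input."""
    torch.manual_seed(seed)
    in1 = n_cards + n_actions + (n_actions if sad else 0)  # agent 1's input width
    net0 = nn.Sequential(nn.Linear(n_cards, 32), nn.ReLU(), nn.Linear(32, n_actions))
    net1 = nn.Sequential(nn.Linear(in1, 32), nn.ReLU(), nn.Linear(32, n_actions))
    opt = optim.Adam(list(net0.parameters()) + list(net1.parameters()), lr=1e-3)

    for j in range(episodes):
        eps = max(eps_final, 1.0 - 2.0 * j / episodes)  # linear anneal (assumed)
        c0 = torch.randint(n_cards, (bs,))
        c1 = torch.randint(n_cards, (bs,))

        # agent 0: epsilon-greedy on its own card
        q0_all = net0(torch.eye(n_cards)[c0])
        with torch.no_grad():
            u0_greedy = q0_all.argmax(1)
            u0 = torch.where(torch.rand(bs) < eps,
                             torch.randint(n_actions, (bs,)), u0_greedy)

        # agent 1: sees its card, agent 0's executed action and (if SAD) the greedy one
        obs1 = [torch.eye(n_cards)[c1], torch.eye(n_actions)[u0]]
        if sad:
            obs1.append(torch.eye(n_actions)[u0_greedy])
        q1_all = net1(torch.cat(obs1, dim=1))
        with torch.no_grad():
            u1_greedy = q1_all.argmax(1)
            u1 = torch.where(torch.rand(bs) < eps,
                             torch.randint(n_actions, (bs,)), u1_greedy)

        # both Q-functions regress to the shared joint reward (bandit-style target)
        r = payoff[c0, c1, u0, u1]
        idx = torch.arange(bs)
        loss = ((q0_all[idx, u0] - r) ** 2 + (q1_all[idx, u1] - r) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # greedy evaluation over all card combinations
    with torch.no_grad():
        c0 = torch.arange(n_cards).repeat_interleave(n_cards)
        c1 = torch.arange(n_cards).repeat(n_cards)
        u0 = net0(torch.eye(n_cards)[c0]).argmax(1)
        obs1 = [torch.eye(n_cards)[c1], torch.eye(n_actions)[u0]]
        if sad:
            obs1.append(torch.eye(n_actions)[u0])  # greedy == executed at test time
        u1 = net1(torch.cat(obs1, dim=1)).argmax(1)
        return payoff[c0, c1, u0, u1].mean().item()


print('without greedy input:', run(sad=False))
print('with greedy input (SAD):', run(sad=True))
</code>

With the greedy input enabled, agent 1 can read agent 0's intended (greedy) action, and therefore whatever its policy encodes about its card, even while agent 0 is still exploring; without it, agent 1 only ever conditions on the noisy exploratory action, which is the gap the plot produced by the script above is meant to show.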