Skip to content

Commit

Permalink
fix indentation and issues with shape.
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasColas committed Sep 4, 2023
1 parent fe307e8 commit 0372930
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 21 deletions.
29 changes: 14 additions & 15 deletions PokerPlus/DeepCFR/game_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from PokerPlus.DeepCFR.memory import AdvantageMemory, StrategyMemory

from PokerPlus.DeepCFR.utils import get_opponent_player_num
from PokerPlus.DeepCFR.utils import get_opponent_player_num, available_moves


def compute_strategy(state: TexasHoldEm, strategy_net: DeepCFRModel, nb_actions=5):
Expand All @@ -28,13 +28,22 @@ def compute_strategy(state: TexasHoldEm, strategy_net: DeepCFRModel, nb_actions=
"""

# Get the available moves for the current player
legal_actions = state.get_available_moves()[:nb_actions]
print("state: ", state)
legal_actions = available_moves(state, nb_actions)
player = state.current_player

cards_board = [card_to_int[str(card)] for card in state.board]
flop = [card for card in cards_board[:3]]
turn = cards_board[3:4]
river = cards_board[4:5]
hole = state.get_hand(player)
hole = [card_to_int[str(card)] for card in hole]
cards = hole + cards_board
cards = [hole, flop]
if turn:
cards.extend([turn])

if river:
cards.extend([river])
# bets
bets = [val_bet for val_bet in state._get_last_pot().player_amounts.values()]
cards = [torch.tensor(c, dtype=torch.long) for c in cards]
Expand Down Expand Up @@ -73,17 +82,6 @@ def get_info_set(game: TexasHoldEm, player: int):
return cards, bets


def available_moves(game : TexasHoldEm, nb_actions : int):
actions = game.get_available_moves()
available_actions = []
for index, action in enumerate(actions):
if action[0] < nb_actions:
available_actions.append(action[1])
else:
return available_actions



def traverse(
game: TexasHoldEm,
current_player_to_compute_strategy: int,
Expand Down Expand Up @@ -128,9 +126,10 @@ def traverse(
) # Insert infoset and its action advantages
return np.max(action_values) # Return the value of the best action
else:
available_moves_ = available_moves(game, nb_actions)
opponent_num = get_opponent_player_num(current_player_to_compute_strategy)
sigma_t = compute_strategy(
get_info_set(game, opponent_num), theta2, nb_actions
game, theta2, nb_actions
) # Compute opponent's strategy

StrategyMemory.insert(
Expand Down
9 changes: 8 additions & 1 deletion PokerPlus/DeepCFR/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(self, dim):
self.card = nn.Embedding(52, dim)

def forward(self, input):
print("forward of CardEmbedding")
print("input: ", input)
B, num_cards = input.shape
x = input.view(-1)
valid = x.ge(0).float() # -1 means 'no card'
Expand Down Expand Up @@ -54,8 +56,13 @@ def forward(self, cards, bets):
# 1. card branch
# embed hole, flop, and optionally turn and river
card_embs = []
print("cards: ", cards)
print("bets: ", bets)

for embedding, card_group in zip(self.card_embeddings, cards):
card_embs.append(embedding(card_group))
print("do embedding")
print("card_group: ", card_group)
card_embs.append(embedding(card_group.view(1, -1)))
card_embs = torch.cat(card_embs, dim=1)
x = F.relu(self.card1(card_embs))
x = F.relu(self.card2(x))
Expand Down
13 changes: 12 additions & 1 deletion PokerPlus/DeepCFR/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from texasholdem.card.deck import Deck
from texasholdem.card.card import Card
from texasholdem.game.game import TexasHoldEm

d = Deck()
card_to_int = {}
Expand All @@ -17,4 +18,14 @@ def get_opponent_player_num(current_player: int):
if current_player == 1:
return 0
else:
return 1
return 1


def available_moves(game: TexasHoldEm, nb_actions: int):
    """Return the second field of the first ``nb_actions`` available moves.

    The moves are read from ``game.get_available_moves()`` in the order that
    call yields them. Each entry is indexed with ``[1]``; presumably entries
    are ``(action_type, amount)`` pairs -- TODO confirm against the
    texasholdem API.

    Args:
        game: Game whose currently available moves are queried.
        nb_actions: Maximum number of moves to keep. Zero or a negative
            value yields an empty list.

    Returns:
        A list with at most ``nb_actions`` elements. A list is always
        returned, including when the game offers ``nb_actions`` moves or
        fewer (the previous implementation only returned from inside the
        loop and could fall through to an implicit ``None``).
    """
    available_actions = []
    for index, action in enumerate(game.get_available_moves()):
        if index >= nb_actions:
            # Enough moves collected; the remaining ones are ignored.
            break
        available_actions.append(action[1])
    return available_actions
4 changes: 2 additions & 2 deletions random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
print("game available moves slicing : ", game.get_available_moves()[:5])
for move in game.get_available_moves():
print("move : ", move)
gui.run_step()

gui.run_step()
15 changes: 13 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
from PokerPlus.DeepCFR.deep_cfr import deep_cfr, save_deep_cfr
from texasholdem.game.game import TexasHoldEm


def main():
print("create game")
game = TexasHoldEm(buyin=1500, big_blind=80, small_blind=40, max_players=2)
game.start_hand()
save_deep_cfr("", "DeepCFR", nb_iterations=10000, nb_players=2, nb_game_tree_traversals=200, game=game, n_actions=3, n_card_types=4, n_bets=20)
save_deep_cfr(
"",
"DeepCFR",
nb_iterations=10000,
nb_players=2,
nb_game_tree_traversals=200,
game=game,
n_actions=3,
n_card_types=4,
n_bets=20,
)


if __name__ == "__main__":
main()
main()

0 comments on commit 0372930

Please sign in to comment.