diff --git a/PokerPlus/DeepCFR/game_tree.py b/PokerPlus/DeepCFR/game_tree.py index 6d956e8..7ea4a44 100644 --- a/PokerPlus/DeepCFR/game_tree.py +++ b/PokerPlus/DeepCFR/game_tree.py @@ -17,7 +17,7 @@ from PokerPlus.DeepCFR.memory import AdvantageMemory, StrategyMemory -from PokerPlus.DeepCFR.utils import get_opponent_player_num +from PokerPlus.DeepCFR.utils import get_opponent_player_num, available_moves def compute_strategy(state: TexasHoldEm, strategy_net: DeepCFRModel, nb_actions=5): @@ -28,13 +28,22 @@ def compute_strategy(state: TexasHoldEm, strategy_net: DeepCFRModel, nb_actions= """ # Get the available moves for the current player - legal_actions = state.get_available_moves()[:nb_actions] + print("state: ", state) + legal_actions = available_moves(state, nb_actions) player = state.current_player cards_board = [card_to_int[str(card)] for card in state.board] + flop = [card for card in cards_board[:3]] + turn = cards_board[3:4] + river = cards_board[4:5] hole = state.get_hand(player) hole = [card_to_int[str(card)] for card in hole] - cards = hole + cards_board + cards = [hole, flop] + if turn: + cards.extend([turn]) + + if river: + cards.extend([river]) # bets bets = [val_bet for val_bet in state._get_last_pot().player_amounts.values()] cards = [torch.tensor(c, dtype=torch.long) for c in cards] @@ -73,17 +82,6 @@ def get_info_set(game: TexasHoldEm, player: int): return cards, bets -def available_moves(game : TexasHoldEm, nb_actions : int): - actions = game.get_available_moves() - available_actions = [] - for index, action in enumerate(actions): - if action[0] < nb_actions: - available_actions.append(action[1]) - else: - return available_actions - - - def traverse( game: TexasHoldEm, current_player_to_compute_strategy: int, @@ -128,9 +126,10 @@ def traverse( ) # Insert infoset and its action advantages return np.max(action_values) # Return the value of the best action else: + available_moves_ = available_moves(game, nb_actions) opponent_num = get_opponent_player_num(current_player_to_compute_strategy) sigma_t = compute_strategy( - get_info_set(game, opponent_num), theta2, nb_actions + game, theta2, nb_actions ) # Compute opponent's strategy StrategyMemory.insert( diff --git a/PokerPlus/DeepCFR/nn.py b/PokerPlus/DeepCFR/nn.py index 3341d9d..ded11ec 100644 --- a/PokerPlus/DeepCFR/nn.py +++ b/PokerPlus/DeepCFR/nn.py @@ -20,6 +20,8 @@ def __init__(self, dim): self.card = nn.Embedding(52, dim) def forward(self, input): + print("forward of CardEmbedding") + print("input: ", input) B, num_cards = input.shape x = input.view(-1) valid = x.ge(0).float() # -1 means 'no card' @@ -54,8 +56,13 @@ def forward(self, cards, bets): # 1. card branch # embed hole, flop, and optionally turn and river card_embs = [] + print("cards: ", cards) + print("bets: ", bets) + for embedding, card_group in zip(self.card_embeddings, cards): - card_embs.append(embedding(card_group)) + print("do embedding") + print("card_group: ", card_group) + card_embs.append(embedding(card_group.view(1, -1))) card_embs = torch.cat(card_embs, dim=1) x = F.relu(self.card1(card_embs)) x = F.relu(self.card2(x)) diff --git a/PokerPlus/DeepCFR/utils.py b/PokerPlus/DeepCFR/utils.py index 7c278ef..9546eb1 100644 --- a/PokerPlus/DeepCFR/utils.py +++ b/PokerPlus/DeepCFR/utils.py @@ -1,5 +1,6 @@ from texasholdem.card.deck import Deck from texasholdem.card.card import Card +from texasholdem.game.game import TexasHoldEm d = Deck() card_to_int = {} @@ -17,4 +18,14 @@ def get_opponent_player_num(current_player: int): if current_player == 1: return 0 else: - return 1 \ No newline at end of file + return 1 + + +def available_moves(game: TexasHoldEm, nb_actions: int): + actions = game.get_available_moves() + available_actions = [] + for index, action in enumerate(actions): + if index < nb_actions: + available_actions.append(action[1]) + else: + return available_actions diff --git a/random_test.py b/random_test.py index de7cff0..becbf8c 100644 --- a/random_test.py +++ b/random_test.py @@ -13,5 +13,5 @@ print("game available moves slicing : ", game.get_available_moves()[:5]) for move in game.get_available_moves(): print("move : ", move) - - gui.run_step() \ No newline at end of file + + gui.run_step() diff --git a/test.py b/test.py index 232c581..a89238b 100644 --- a/test.py +++ b/test.py @@ -1,12 +1,23 @@ from PokerPlus.DeepCFR.deep_cfr import deep_cfr, save_deep_cfr from texasholdem.game.game import TexasHoldEm + def main(): print("create game") game = TexasHoldEm(buyin=1500, big_blind=80, small_blind=40, max_players=2) game.start_hand() - save_deep_cfr("", "DeepCFR", nb_iterations=10000, nb_players=2, nb_game_tree_traversals=200, game=game, n_actions=3, n_card_types=4, n_bets=20) + save_deep_cfr( + "", + "DeepCFR", + nb_iterations=10000, + nb_players=2, + nb_game_tree_traversals=200, + game=game, + n_actions=3, + n_card_types=4, + n_bets=20, + ) if __name__ == "__main__": - main() \ No newline at end of file + main()