From 238e791a7cd1d8d8c2488d5af4992f36c37c7414 Mon Sep 17 00:00:00 2001 From: lucascolas Date: Wed, 6 Sep 2023 21:54:48 -0400 Subject: [PATCH] change card types --- PokerPlus/DeepCFR/deep_cfr.py | 2 +- PokerPlus/DeepCFR/nn.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/PokerPlus/DeepCFR/deep_cfr.py b/PokerPlus/DeepCFR/deep_cfr.py index ada9ff0..6f68cff 100644 --- a/PokerPlus/DeepCFR/deep_cfr.py +++ b/PokerPlus/DeepCFR/deep_cfr.py @@ -33,7 +33,7 @@ def deep_cfr( nb_game_tree_traversals: int = 200, game: TexasHoldEm = None, n_actions: int = 3, - n_card_types: int = 4, + n_card_types: int = 52, n_bets: int = 20, ): """ diff --git a/PokerPlus/DeepCFR/nn.py b/PokerPlus/DeepCFR/nn.py index 70f1a8a..6a541f5 100644 --- a/PokerPlus/DeepCFR/nn.py +++ b/PokerPlus/DeepCFR/nn.py @@ -63,8 +63,11 @@ def forward(self, cards, bets): print("do embedding") print("card_group: ", card_group) if card_group.numel(): - card_embs.append(embedding(card_group.view(1, -1))) + card_embs.append(embedding(card_group.view(-1, 1))) + card_embs = torch.cat(card_embs, dim=1) + + print("cards embs shape : ", card_embs.shape) x = F.relu(self.card1(card_embs)) x = F.relu(self.card2(x)) x = F.relu(self.card3(x))