# rlttt.py — reinforcement-learning tic-tac-toe (two self-play agents, then human play)
  1. import json
  2. import numpy as np
  3. import pickle
  4. BOARD_ROWS = 3
  5. BOARD_COLS = 3
  6. class State:
  7. def __init__(self, p1, p2):
  8. self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
  9. self.p1 = p1
  10. self.p2 = p2
  11. self.isEnd = False
  12. self.boardHash = None
  13. # init p1 plays first
  14. self.playerSymbol = 1
  15. self.wins = {}
  16. self.wins[p1.name] = 0
  17. self.wins[p2.name] = 0
  18. self.wins["tie"] = 0
  19. # get unique hash of current board state
  20. def getHash(self):
  21. self.boardHash = str(self.board.reshape(BOARD_COLS * BOARD_ROWS))
  22. return self.boardHash
  23. def winner(self):
  24. # row
  25. for i in range(BOARD_ROWS):
  26. if sum(self.board[i, :]) == 3:
  27. self.isEnd = True
  28. self.wins[self.p1.name] += 1
  29. return 1
  30. if sum(self.board[i, :]) == -3:
  31. self.isEnd = True
  32. self.wins[self.p2.name] += 1
  33. return -1
  34. # col
  35. for i in range(BOARD_COLS):
  36. if sum(self.board[:, i]) == 3:
  37. self.isEnd = True
  38. self.wins[self.p1.name] += 1
  39. return 1
  40. if sum(self.board[:, i]) == -3:
  41. self.isEnd = True
  42. self.wins[self.p2.name] += 1
  43. return -1
  44. # diagonal
  45. diag_sum1 = sum([self.board[i, i] for i in range(BOARD_COLS)])
  46. diag_sum2 = sum([self.board[i, BOARD_COLS - i - 1] for i in range(BOARD_COLS)])
  47. diag_sum = max(abs(diag_sum1), abs(diag_sum2))
  48. if diag_sum == 3:
  49. self.isEnd = True
  50. if diag_sum1 == 3 or diag_sum2 == 3:
  51. self.wins[self.p1.name] += 1
  52. return 1
  53. else:
  54. self.wins[self.p2.name] += 1
  55. return -1
  56. # tie
  57. # no available positions
  58. if len(self.availablePositions()) == 0:
  59. self.isEnd = True
  60. self.wins["tie"] += 1
  61. return 0
  62. # not end
  63. self.isEnd = False
  64. return None
  65. def availablePositions(self):
  66. positions = []
  67. for i in range(BOARD_ROWS):
  68. for j in range(BOARD_COLS):
  69. if self.board[i, j] == 0:
  70. positions.append((i, j)) # need to be tuple
  71. return positions
  72. def updateState(self, position):
  73. self.board[position] = self.playerSymbol
  74. # switch to another player
  75. self.playerSymbol = -1 if self.playerSymbol == 1 else 1
  76. # only when game ends
  77. def giveReward(self):
  78. result = self.winner()
  79. # backpropagate reward
  80. if result == 1:
  81. self.p1.feedReward(1)
  82. self.p2.feedReward(0)
  83. elif result == -1:
  84. self.p1.feedReward(0)
  85. self.p2.feedReward(1)
  86. else:
  87. self.p1.feedReward(0.1)
  88. self.p2.feedReward(0.5)
  89. # board reset
  90. def reset(self):
  91. self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
  92. self.boardHash = None
  93. self.isEnd = False
  94. self.playerSymbol = 1
  95. def play(self, rounds=100):
  96. for i in range(rounds):
  97. if i % 1000 == 0:
  98. print(f"Rounds {i}; Current Results: {json.dumps(self.wins)}")
  99. while not self.isEnd:
  100. # Player 1
  101. positions = self.availablePositions()
  102. p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
  103. # take action and upate board state
  104. self.updateState(p1_action)
  105. board_hash = self.getHash()
  106. self.p1.addState(board_hash)
  107. # check board status if it is end
  108. win = self.winner()
  109. if win is not None:
  110. # self.showBoard()
  111. # ended with p1 either win or draw
  112. self.giveReward()
  113. self.p1.reset()
  114. self.p2.reset()
  115. self.reset()
  116. break
  117. else:
  118. # Player 2
  119. positions = self.availablePositions()
  120. p2_action = self.p2.chooseAction(positions, self.board, self.playerSymbol)
  121. self.updateState(p2_action)
  122. board_hash = self.getHash()
  123. self.p2.addState(board_hash)
  124. win = self.winner()
  125. if win is not None:
  126. # self.showBoard()
  127. # ended with p2 either win or draw
  128. self.giveReward()
  129. self.p1.reset()
  130. self.p2.reset()
  131. self.reset()
  132. break
  133. # play with human
  134. def play2(self):
  135. while not self.isEnd:
  136. # Player 1
  137. positions = self.availablePositions()
  138. p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
  139. # take action and upate board state
  140. self.updateState(p1_action)
  141. self.showBoard()
  142. # check board status if it is end
  143. win = self.winner()
  144. if win is not None:
  145. if win == 1:
  146. print(self.p1.name, "wins!")
  147. else:
  148. print("tie!")
  149. self.reset()
  150. break
  151. else:
  152. # Player 2
  153. positions = self.availablePositions()
  154. p2_action = self.p2.chooseAction(positions)
  155. self.updateState(p2_action)
  156. self.showBoard()
  157. win = self.winner()
  158. if win is not None:
  159. if win == -1:
  160. print(self.p2.name, "wins!")
  161. else:
  162. print("tie!")
  163. self.reset()
  164. break
  165. def showBoard(self):
  166. # p1: x p2: o
  167. for i in range(0, BOARD_ROWS):
  168. print('-------------')
  169. out = '| '
  170. for j in range(0, BOARD_COLS):
  171. if self.board[i, j] == 1:
  172. token = 'x'
  173. if self.board[i, j] == -1:
  174. token = 'o'
  175. if self.board[i, j] == 0:
  176. token = ' '
  177. out += token + ' | '
  178. print(out)
  179. print('-------------')
  180. class Player:
  181. def __init__(self, name, exp_rate=0.3):
  182. self.name = name
  183. self.states = [] # record all positions taken
  184. self.lr = 0.2
  185. self.exp_rate = exp_rate
  186. self.decay_gamma = 0.9
  187. self.states_value = {} # state -> value
  188. def getHash(self, board):
  189. boardHash = str(board.reshape(BOARD_COLS * BOARD_ROWS))
  190. return boardHash
  191. def chooseAction(self, positions, current_board, symbol):
  192. if np.random.uniform(0, 1) <= self.exp_rate:
  193. # take random action
  194. idx = np.random.choice(len(positions))
  195. action = positions[idx]
  196. else:
  197. value_max = -999
  198. for p in positions:
  199. next_board = current_board.copy()
  200. next_board[p] = symbol
  201. next_boardHash = self.getHash(next_board)
  202. value = 0 if self.states_value.get(next_boardHash) is None else self.states_value.get(next_boardHash)
  203. # print("value", value)
  204. if value >= value_max:
  205. value_max = value
  206. action = p
  207. # print("{} takes action {}".format(self.name, action))
  208. return action
  209. # append a hash state
  210. def addState(self, state):
  211. self.states.append(state)
  212. # at the end of game, backpropagate and update states value
  213. def feedReward(self, reward):
  214. for st in reversed(self.states):
  215. if self.states_value.get(st) is None:
  216. self.states_value[st] = 0
  217. self.states_value[st] += self.lr * (self.decay_gamma * reward - self.states_value[st])
  218. reward = self.states_value[st]
  219. def reset(self):
  220. self.states = []
  221. def savePolicy(self):
  222. fw = open('policy_' + str(self.name), 'wb')
  223. pickle.dump(self.states_value, fw)
  224. fw.close()
  225. def loadPolicy(self, file):
  226. fr = open(file, 'rb')
  227. self.states_value = pickle.load(fr)
  228. fr.close()
  229. class HumanPlayer:
  230. def __init__(self, name):
  231. self.name = name
  232. def chooseAction(self, positions):
  233. while True:
  234. row = int(input("Input your action row:"))
  235. col = int(input("Input your action col:"))
  236. action = (row, col)
  237. if action in positions:
  238. return action
  239. # append a hash state
  240. def addState(self, state):
  241. pass
  242. # at the end of game, backpropagate and update states value
  243. def feedReward(self, reward):
  244. pass
  245. def reset(self):
  246. pass
  247. if __name__ == "__main__":
  248. # training
  249. p1 = Player("p1")
  250. p2 = Player("p2")
  251. st = State(p1, p2)
  252. print("training...")
  253. st.play(50000)
  254. p1.savePolicy()
  255. p2.savePolicy()
  256. # play with human
  257. p1 = Player("computer", exp_rate=0)
  258. p1.loadPolicy("policy_p1")
  259. p2 = HumanPlayer("human")
  260. st = State(p1, p2)
  261. st.play2()