From 08239a2181993426412cc30fa09ef47438d88f55 Mon Sep 17 00:00:00 2001
From: Xiaofei Wang <xiaofei@student.unimelb.edu.au>
Date: Wed, 11 Mar 2020 21:04:04 +1100
Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=B0=E6=96=87=E4=BB=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 AZUL Learn Opponent/learn_opponent.py | 48 +++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 AZUL Learn Opponent/learn_opponent.py

diff --git a/AZUL Learn Opponent/learn_opponent.py b/AZUL Learn Opponent/learn_opponent.py
new file mode 100644
index 0000000..e3ff9df
--- /dev/null
+++ b/AZUL Learn Opponent/learn_opponent.py	
@@ -0,0 +1,48 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import random
+
+from tensorflow import keras
+
+# Helper libraries
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+class Net_model:
+
+
+    def train(self, x, y):
+        model = keras.models.Sequential([
+            keras.layers.Flatten(input_shape=(150, 6)),
+            keras.layers.Dense(1024, activation='relu'),
+            # keras.layers.Dropout(0.2),
+            keras.layers.Dense(150, activation='softmax')
+        ])
+
+        model.compile(optimizer='adam',
+                      loss='sparse_categorical_crossentropy',
+                      metrics=['accuracy'])
+        train_num = len(x)//5*4
+        print('tn', train_num)
+        DATA_X = np.array(x)
+        DATA_Y = np.array(y)
+
+        TR_X, TE_X = DATA_X[:train_num], DATA_X[train_num:]
+        TR_Y, TE_Y = DATA_Y[:train_num], DATA_Y[train_num:]
+
+        cost = model.fit(TR_X, TR_Y, epochs=50)
+        print(cost)
+        cost = model.evaluate(TE_X, TE_Y)
+        print(cost)
+        Y_pred = model.predict(TE_X)
+
+        # print(TE_X[0])
+        # print(np.argmax(Y_pred[0]))
+        model.save('naive_model.h5')
+
+    def perdict(self, X):
+        model = keras.models.load_model('naive_model.h5')
+        Y_pred = model.predict(X)
+        return Y_pred
+
-- 
GitLab