import networkx as nx
import pennylane as qml
from pennylane import numpy as pnp
import numpy as np
from numpy.random import random
from pennylane.templates import BasicEntanglerLayers
import sys
import time
import torch

def benchmark_machine_learning_torch(n_features, n_samples, *, n_steps=20, step_size=0.05):
    """Train a hybrid quantum-classical machine learning pipeline.

    The data is generated from two Gaussian blobs (unit scale) centred at -1
    and +1 in every feature, labelled -1 and +1 respectively. The model first
    multiplies each input vector by a trainable classical weight matrix, then
    feeds the result into a quantum model that uses AngleEmbedding for the
    encoding and BasicEntanglerLayers as the trainable circuit; its output is
    the expectation of PauliZ on wire 0. The number of qubits and the number of
    entangler layers both equal ``n_features``. Training runs plain gradient
    descent on the mean squared error.

    Args:
        n_features: Dimension of each input vector; also the number of qubits
            and of entangler layers.
        n_samples: Total number of training samples. Each class gets
            ``n_samples // 2`` samples, so an odd value drops one sample.
        n_steps: Number of gradient-descent steps (default 20).
        step_size: Gradient-descent learning rate (default 0.05).

    Raises:
        ValueError: If ``n_samples < 2``, i.e. there would be no training data.
    """
    if n_samples < 2:
        raise ValueError(f"n_samples must be at least 2, got {n_samples}")

    diff_method = "best"
    interface = "torch"
    device = qml.device("default.qubit", wires=n_features)

    # data: two Gaussian blobs, one per class
    n_per_class = n_samples // 2
    x0 = np.random.normal(loc=-1, scale=1, size=(n_per_class, n_features))
    x1 = np.random.normal(loc=1, scale=1, size=(n_per_class, n_features))
    features = np.concatenate([x0, x1], axis=0)
    # One label per generated row; a length mismatch here would make zip()
    # silently truncate the data and pair rows with the wrong class label.
    labels = np.concatenate([-np.ones(n_per_class), np.ones(n_per_class)], axis=0)

    data = [
        (torch.tensor(xi, dtype=torch.double), torch.tensor(yi, dtype=torch.double))
        for xi, yi in zip(features, labels)
    ]

    @qml.qnode(device, interface=interface, diff_method=diff_method)
    def quantum_model(x, params):
        qml.templates.AngleEmbedding(x, wires=range(len(x)))
        qml.templates.BasicEntanglerLayers(params, wires=range(len(x)))
        return qml.expval(qml.PauliZ(0))

    def hybrid_model(x, w_quantum, w_classical):
        # Classical linear layer followed by the quantum circuit.
        transformed_x = torch.matmul(w_classical, x)
        return quantum_model(transformed_x, w_quantum)

    def average_loss(w_quantum, w_classical):
        # Mean squared error over the whole dataset.
        total = torch.tensor(0, dtype=torch.double)
        for xi, yi in data:
            prediction = hybrid_model(xi, w_quantum, w_classical)
            total = total + (prediction - yi) ** 2
        return total / len(data)

    # BasicEntanglerLayers weights have shape (n_layers, n_wires); here both
    # equal n_features.
    w_quantum = torch.tensor(
        random(size=(n_features, n_features)), requires_grad=True, dtype=torch.double
    )
    w_classical = torch.tensor(
        random(size=(n_features, n_features)), requires_grad=True, dtype=torch.double
    )

    for _ in range(n_steps):
        loss = average_loss(w_quantum, w_classical)
        loss.backward()

        # Update the leaf tensors in place without recording the update in the graph.
        with torch.no_grad():
            w_quantum -= step_size * w_quantum.grad
            w_classical -= step_size * w_classical.grad
        w_quantum.grad = None
        w_classical.grad = None


# Command-line entry point: python benchmark.py <num_samples>
if len(sys.argv) != 2:
    # sys.exit with a string prints it to stderr and exits with status 1.
    sys.exit("python benchmark.py <num_samples>")

n_features = 10
n_samples = int(sys.argv[1])

# perf_counter is monotonic and high-resolution, unlike wall-clock time.time(),
# so the measured duration cannot be skewed by system clock adjustments.
# NOTE(review): the PROFILE_BEGIN/PROFILE_END markers are kept verbatim at
# column 0 -- presumably consumed by an external profiling harness; confirm
# before moving or re-indenting them.
start = time.perf_counter()
#PROFILE_BEGIN
benchmark_machine_learning_torch(n_features,n_samples)
#PROFILE_END
end = time.perf_counter()
print("time =",(end-start))
