# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy.random import random
import pennylane as qml
# import torch
from pennylane import numpy as pnp
import torch
import sys
import time
from pennylane.templates import BasicEntanglerLayers



def benchmark_optimization_torch(n_wires, *, n_steps=3, n_layers=6):
    """Run a short Torch SGD optimization of a BasicEntanglerLayers circuit.

    Builds a QNode on a ``default.qubit`` device with the ``torch`` interface
    and ``diff_method="best"``, initialises the layer parameters with uniform
    random values, and performs ``n_steps`` SGD updates (learning rate 0.1)
    on the expectation value of ``PauliZ`` on wire 0.

    Args:
        n_wires (int): number of wires of the device and of the template.
        n_steps (int): number of optimizer steps to perform. Defaults to 3.
        n_layers (int): number of entangler layers, i.e. the first dimension
            of the parameter array. Defaults to 6.

    Returns:
        None: the function is run only for its computational cost.
    """
    # One rotation parameter per (layer, wire) pair.
    initial_params = random(size=(n_layers, n_wires))
    measurement = qml.expval(qml.PauliZ(0))

    diff_method = "best"
    interface = "torch"
    device = qml.device("default.qubit", wires=n_wires)

    @qml.qnode(device, interface=interface, diff_method=diff_method)
    def circuit(params_):
        BasicEntanglerLayers(params_, wires=range(n_wires))
        # The measurement object is created outside the QNode, so it is
        # queued explicitly here (presumably so that the circuit records
        # it -- confirm against the PennyLane version in use).
        measurement.queue()
        return measurement

    params = torch.tensor(initial_params, requires_grad=True)
    opt = torch.optim.SGD([params], lr=0.1)

    def closure():
        # Re-evaluate the circuit and its gradient; called by ``opt.step``.
        opt.zero_grad()
        loss = circuit(params)
        loss.backward()
        return loss

    for _ in range(n_steps):
        opt.step(closure)

# Command-line driver: expects exactly one argument, the number of qubits,
# runs the benchmark once and prints the elapsed time in seconds.
_USAGE = "usage: python benchmark.py <qubits>"

if len(sys.argv) != 2:
    print(_USAGE)
    sys.exit(1)

try:
    qubits = int(sys.argv[1])
except ValueError:
    # A non-integer argument is a usage error, not a crash with a traceback.
    print(_USAGE)
    sys.exit(1)

if qubits < 1:
    print(_USAGE)
    sys.exit(1)

# perf_counter() is monotonic and high-resolution, so the measured interval
# is not affected by system clock adjustments.
start = time.perf_counter()
# NOTE(review): the PROFILE markers below look like anchors for external
# tooling; they are kept verbatim and at column 0 -- confirm before moving.
#PROFILE_BEGIN
benchmark_optimization_torch(qubits)
#PROFILE_END
end = time.perf_counter()
print("time =", (end - start))