/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.common.session.helpers;

import java.io.IOException;
import java.io.Serializable;
import java.net.ProtocolException;
import java.security.GeneralSecurityException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.future.CancelOption;
import org.apache.sshd.common.future.DefaultKeyExchangeFuture;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.io.AbstractIoWriteFuture;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.kex.KexState;
import org.apache.sshd.common.session.helpers.AbstractSession;
import org.apache.sshd.common.session.helpers.PendingWriteFuture;
import org.apache.sshd.common.util.ExceptionUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.threads.ThreadUtils;
import org.slf4j.Logger;

public class KeyExchangeMessageHandler {
    protected final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(false);
    protected final ExecutorService flushRunner = Executors.newSingleThreadExecutor();
    protected final AbstractSession session;
    protected final Logger log;
    protected final Queue<PendingWriteFuture> pendingPackets = new ConcurrentLinkedQueue<PendingWriteFuture>();
    protected final AtomicBoolean kexFlushed = new AtomicBoolean(true);
    protected final AtomicBoolean shutDown = new AtomicBoolean();
    protected final AtomicReference<DefaultKeyExchangeFuture> kexFlushedFuture = new AtomicReference();

    public KeyExchangeMessageHandler(AbstractSession session, Logger log) {
        this.session = Objects.requireNonNull(session);
        this.log = Objects.requireNonNull(log);
        DefaultKeyExchangeFuture initialFuture = new DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
        initialFuture.setValue(Boolean.TRUE);
        this.kexFlushedFuture.set(initialFuture);
    }

    public void updateState(Runnable update) {
        this.updateState(() -> {
            update.run();
            return null;
        });
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public <V> V updateState(Supplier<V> update) {
        boolean locked = false;
        if (this.lock.getReadHoldCount() == 0) {
            this.lock.writeLock().lock();
            locked = true;
        }
        try {
            V v = update.get();
            return v;
        }
        finally {
            if (locked) {
                this.lock.writeLock().unlock();
            }
        }
    }

    public DefaultKeyExchangeFuture initNewKeyExchange() {
        return this.updateState(() -> {
            this.kexFlushed.set(false);
            return this.kexFlushedFuture.getAndSet(new DefaultKeyExchangeFuture(this.session.toString(), this.session.getFutureLock()));
        });
    }

    public AbstractMap.SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture> terminateKeyExchange() {
        return this.updateState(() -> {
            int numPending = this.pendingPackets.size();
            if (numPending == 0) {
                this.kexFlushed.set(true);
            }
            return new AbstractMap.SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture>(numPending, this.kexFlushedFuture.get());
        });
    }

    public void shutdown() {
        this.shutDown.set(true);
        AbstractMap.SimpleImmutableEntry items = this.updateState(() -> {
            this.kexFlushed.set(true);
            return new AbstractMap.SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture>(this.pendingPackets.size(), this.kexFlushedFuture.get());
        });
        ((DefaultKeyExchangeFuture)items.getValue()).setValue((Integer)items.getKey() == 0);
        this.flushRunner.shutdownNow();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public IoWriteFuture writePacket(Buffer buffer, long timeout, TimeUnit unit) throws IOException {
        byte[] bufData = buffer.array();
        int cmd = bufData[buffer.rpos()] & 0xFF;
        boolean enqueued = false;
        boolean isLowLevelMessage = cmd <= 49 && cmd != 5 && cmd != 6;
        IoWriteFuture future = null;
        try {
            if (isLowLevelMessage) {
                future = this.session.doWritePacket(buffer);
            } else {
                future = this.writeOrEnqueue(cmd, buffer, timeout, unit);
                enqueued = future instanceof PendingWriteFuture;
            }
        }
        finally {
            this.session.resetIdleTimeout();
        }
        if (!enqueued) {
            try {
                this.session.checkRekey();
            }
            catch (GeneralSecurityException e) {
                if (this.log.isDebugEnabled()) {
                    this.log.debug("writePacket({}) failed ({}) to check re-key: {}", new Object[]{this.session, e.getClass().getSimpleName(), e.getMessage(), e});
                }
                throw (ProtocolException)ValidateUtils.initializeExceptionCause((Throwable)new ProtocolException("Failed (" + e.getClass().getSimpleName() + ") to check re-key necessity: " + e.getMessage()), (Throwable)e);
            }
            catch (Exception e) {
                ExceptionUtils.rethrowAsIoException((Throwable)e);
            }
        }
        return future;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected IoWriteFuture writeOrEnqueue(int cmd, Buffer buffer, long timeout, TimeUnit unit) throws IOException {
        boolean holdsFutureLock = Thread.holdsLock(this.session.getFutureLock());
        while (true) {
            DefaultKeyExchangeFuture block;
            block13: {
                block = null;
                this.lock.readLock().lock();
                try {
                    boolean kexDone;
                    if (this.shutDown.get()) {
                        throw new SshException("Write attempt on closing session: " + SshConstants.getCommandMessageName((int)cmd));
                    }
                    KexState state = this.session.kexState.get();
                    boolean bl = kexDone = KexState.DONE.equals((Object)state) || KexState.KEYS.equals((Object)state);
                    if (kexDone && this.kexFlushed.get()) {
                        IoWriteFuture ioWriteFuture = this.session.doWritePacket(buffer);
                        return ioWriteFuture;
                    }
                    if (!holdsFutureLock && this.isBlockAllowed(cmd)) {
                        block = this.kexFlushedFuture.get();
                        break block13;
                    }
                    if (kexDone && this.log.isDebugEnabled()) {
                        this.log.debug("writeOrEnqueue({})[{}]: Queuing packet while flushing", (Object)this.session, (Object)SshConstants.getCommandMessageName((int)cmd));
                    }
                    PendingWriteFuture pendingWriteFuture = this.enqueuePendingPacket(cmd, buffer);
                    return pendingWriteFuture;
                }
                finally {
                    this.lock.readLock().unlock();
                }
            }
            if (block == null) continue;
            if (timeout <= 0L || unit == null) {
                if (this.log.isDebugEnabled()) {
                    this.log.debug("writeOrEnqueue({})[{}]: Blocking thread {} until KEX is over", new Object[]{this.session, SshConstants.getCommandMessageName((int)cmd), Thread.currentThread()});
                }
                block.await(new CancelOption[0]);
            } else {
                if (this.log.isDebugEnabled()) {
                    this.log.debug("writeOrEnqueue({})[{}]: Blocking thread {} until KEX is over or timeout {} {}", new Object[]{this.session, SshConstants.getCommandMessageName((int)cmd), Thread.currentThread(), timeout, unit});
                }
                block.await(timeout, unit, new CancelOption[0]);
            }
            if (!this.log.isDebugEnabled()) continue;
            this.log.debug("writeOrEnqueue({})[{}]: Thread {} awakens after KEX done", new Object[]{this.session, SshConstants.getCommandMessageName((int)cmd), Thread.currentThread()});
        }
    }

    protected boolean isBlockAllowed(int cmd) {
        boolean isChannelData = cmd == 94 || cmd == 95;
        return isChannelData && !ThreadUtils.isInternalThread();
    }

    protected PendingWriteFuture enqueuePendingPacket(int cmd, Buffer buffer) {
        String cmdName = SshConstants.getCommandMessageName((int)cmd);
        PendingWriteFuture future = new PendingWriteFuture(cmdName, buffer);
        this.pendingPackets.add(future);
        int numPending = this.pendingPackets.size();
        if (this.log.isDebugEnabled()) {
            if (numPending == 1) {
                this.log.debug("enqueuePendingPacket({})[{}] Start flagging packets as pending until key exchange is done", (Object)this.session, (Object)cmdName);
            } else {
                this.log.debug("enqueuePendingPacket({})[{}] enqueued until key exchange is done (pending={})", new Object[]{this.session, cmdName, numPending});
            }
        }
        return future;
    }

    protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
        this.flushRunner.submit(() -> {
            ArrayList<AbstractMap.SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingFutures = new ArrayList<AbstractMap.SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>>();
            boolean allFlushed = false;
            DefaultKeyExchangeFuture newFuture = null;
            Serializable error = null;
            try {
                boolean warnedAboutChunkLimit = false;
                int lastSize = -1;
                int take = 2;
                while (!allFlushed) {
                    this.lock.writeLock().lock();
                    try {
                        PendingWriteFuture pending;
                        if (this.pendingPackets.isEmpty()) {
                            if (this.log.isDebugEnabled()) {
                                this.log.debug("flushQueue({}): All packets at end of KEX flushed", (Object)this.session);
                            }
                            this.kexFlushed.set(true);
                            allFlushed = true;
                            break;
                        }
                        if (!this.session.isOpen()) {
                            this.log.info("flushQueue({}): Session closed while flushing pending packets at end of KEX", (Object)this.session);
                            AbstractIoWriteFuture aborted = new AbstractIoWriteFuture(this.session, null){};
                            aborted.setValue((Object)new SshException("Session closed while flushing pending packets at end of KEX"));
                            this.drainQueueTo(pendingFutures, (IoWriteFuture)aborted);
                            this.kexFlushed.set(true);
                            error = Boolean.FALSE;
                            break;
                        }
                        DefaultKeyExchangeFuture currentFuture = this.kexFlushedFuture.get();
                        if (currentFuture != flushDone) {
                            if (this.log.isDebugEnabled()) {
                                this.log.debug("flushQueue({}): Stopping flushing pending packets", (Object)this.session);
                            }
                            newFuture = currentFuture;
                            break;
                        }
                        int newSize = this.pendingPackets.size();
                        if (lastSize < 0) {
                            this.log.info("flushQueue({}): {} pending packets to flush", (Object)this.session, (Object)newSize);
                        } else if (newSize >= lastSize) {
                            this.log.info("flushQueue({}): queue size before={} now={}", new Object[]{this.session, lastSize, newSize});
                            if (take < 64) {
                                take *= 2;
                            } else if (!warnedAboutChunkLimit) {
                                warnedAboutChunkLimit = true;
                                this.log.warn("flushQueue({}): maximum queue flush chunk of 64 reached", (Object)this.session);
                            }
                        }
                        lastSize = newSize;
                        if (this.log.isDebugEnabled()) {
                            this.log.debug("flushQueue({}): flushing {} packets", (Object)this.session, (Object)Math.min(lastSize, take));
                        }
                        for (int i = 0; i < take && (pending = this.pendingPackets.poll()) != null; ++i) {
                            IoWriteFuture written;
                            try {
                                if (this.log.isTraceEnabled()) {
                                    this.log.trace("flushQueue({}): Flushing a packet at end of KEX for {}", (Object)this.session, pending.getId());
                                }
                                written = this.session.doWritePacket(pending.getBuffer());
                            }
                            catch (Throwable e2) {
                                this.log.error("flushQueue({}): Exception while flushing packet at end of KEX for {}", new Object[]{this.session, pending.getId(), e2});
                                AbstractIoWriteFuture aborted = new AbstractIoWriteFuture(pending.getId(), null){};
                                aborted.setValue((Object)e2);
                                pendingFutures.add(new AbstractMap.SimpleImmutableEntry<PendingWriteFuture, 2>(pending, aborted));
                                this.drainQueueTo(pendingFutures, (IoWriteFuture)aborted);
                                this.kexFlushed.set(true);
                                error = e2;
                                this.lock.writeLock().unlock();
                                if (allFlushed) {
                                    flushDone.setValue(Boolean.TRUE);
                                } else if (error != null) {
                                    flushDone.setValue(error);
                                    if (error instanceof Throwable) {
                                        this.session.exceptionCaught((Throwable)error);
                                    }
                                } else if (newFuture != null) {
                                    newFuture.addListener(f -> {
                                        Throwable failed = f.getException();
                                        flushDone.setValue(failed != null ? failed : Boolean.TRUE);
                                    });
                                }
                                pendingFutures.forEach(e -> ((IoWriteFuture)e.getValue()).addListener((SshFutureListener)e.getKey()));
                                return;
                            }
                            pendingFutures.add(new AbstractMap.SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>(pending, written));
                            if (this.log.isTraceEnabled()) {
                                this.log.trace("flushQueue({}): Flushed a packet at end of KEX for {}", (Object)this.session, pending.getId());
                            }
                            this.session.resetIdleTimeout();
                        }
                        if (!this.pendingPackets.isEmpty()) continue;
                        if (this.log.isDebugEnabled()) {
                            this.log.debug("flushQueue({}): All packets at end of KEX flushed", (Object)this.session);
                        }
                        this.kexFlushed.set(true);
                        allFlushed = true;
                        break;
                    }
                    finally {
                        this.lock.writeLock().unlock();
                    }
                }
            }
            finally {
                if (allFlushed) {
                    flushDone.setValue(Boolean.TRUE);
                } else if (error != null) {
                    flushDone.setValue(error);
                    if (error instanceof Throwable) {
                        this.session.exceptionCaught((Throwable)error);
                    }
                } else if (newFuture != null) {
                    newFuture.addListener(f -> {
                        Throwable failed = f.getException();
                        flushDone.setValue(failed != null ? failed : Boolean.TRUE);
                    });
                }
                pendingFutures.forEach(e -> ((IoWriteFuture)e.getValue()).addListener((SshFutureListener)e.getKey()));
            }
        });
    }

    private void drainQueueTo(List<AbstractMap.SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingAborted, IoWriteFuture aborted) {
        PendingWriteFuture pending = this.pendingPackets.poll();
        while (pending != null) {
            pendingAborted.add(new AbstractMap.SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>(pending, aborted));
            pending = this.pendingPackets.poll();
        }
    }
}

