/*
 * Decompiled with CFR 0.152.
 */
package org.forgerock.json.jose.jwe.handlers.encryption;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.Arrays;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import javax.security.auth.DestroyFailedException;
import javax.security.auth.Destroyable;
import org.forgerock.json.jose.exceptions.JweDecryptionException;
import org.forgerock.json.jose.exceptions.JweEncryptionException;
import org.forgerock.json.jose.exceptions.JweException;
import org.forgerock.json.jose.jwe.EncryptionMethod;
import org.forgerock.json.jose.jwe.JweAlgorithm;
import org.forgerock.json.jose.jwe.JweAlgorithmType;
import org.forgerock.json.jose.jwe.JweEncryption;
import org.forgerock.json.jose.jwe.JweHeader;
import org.forgerock.json.jose.jwe.handlers.encryption.EncryptionHandler;
import org.forgerock.json.jose.jwk.EcJWK;
import org.forgerock.json.jose.jwk.KeyUse;
import org.forgerock.json.jose.jws.SupportedEllipticCurve;
import org.forgerock.json.jose.utils.Utils;
import org.forgerock.util.Reject;
import org.forgerock.util.annotations.VisibleForTesting;
import org.forgerock.util.encode.Base64url;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EncryptionHandler} for the ECDH-ES family of JWE key-management algorithms
 * ({@code ECDH_ES}, {@code ECDH_ES_A128KW}, {@code ECDH_ES_A192KW}, {@code ECDH_ES_A256KW}).
 * <p>
 * An Elliptic Curve Diffie-Hellman agreement between an ephemeral key pair and the recipient's key
 * yields a shared secret. That secret is run through the Concat KDF (SHA-256) to derive a key of
 * {@code keySize} bits for {@code keyAlgorithm}. All encryption, decryption and key wrapping is then
 * delegated to the wrapped {@code keyWrappingHandler}.
 */
public final class ECDHEncryptionHandler
implements EncryptionHandler {
    private static final String KEY_GENERATION_ALGORITHM = "EC";
    private static final String KEY_AGREEMENT_ALGORITHM = "ECDH";
    private static final String HASH_ALGORITHM = "SHA-256";
    /** Output length of {@link #HASH_ALGORITHM} in bits; used to size the Concat KDF loop. */
    private static final double HASH_LENGTH = 256.0;
    private static final Logger logger = LoggerFactory.getLogger(ECDHEncryptionHandler.class);
    private final EncryptionHandler keyWrappingHandler;
    private final int keySize;
    private final String algorithmId;
    private final String keyAlgorithm;

    private ECDHEncryptionHandler(EncryptionHandler keyWrappingHandler, String keyAlgorithm, int keySize, String algorithmId) {
        this.keyWrappingHandler = keyWrappingHandler;
        this.keySize = keySize;
        this.algorithmId = algorithmId;
        this.keyAlgorithm = keyAlgorithm;
    }

    /**
     * Creates a handler for the given ECDH algorithm.
     * <p>
     * For plain {@code ECDH_ES}, the derived key's size and algorithm come from the content
     * encryption method. The KDF AlgorithmID is the method's JWE standard name. For the
     * {@code +AxxxKW} variants, the derived key is an AES key of 128/192/256 bits. The KDF
     * AlgorithmID is the JWA algorithm name.
     *
     * @param keyWrappingHandler handler that performs the actual key wrapping/encryption.
     * @param algorithm the ECDH JWE algorithm.
     * @param encryptionMethod the content encryption method.
     * @return a new handler.
     * @throws IllegalArgumentException if {@code algorithm} is not an ECDH algorithm.
     */
    public static ECDHEncryptionHandler getInstance(EncryptionHandler keyWrappingHandler, JweAlgorithm algorithm, EncryptionMethod encryptionMethod) {
        String keyAlgorithm;
        int keySize;
        String algorithmId = algorithm.getJwaAlgorithmName();
        switch (algorithm) {
            case ECDH_ES: {
                keySize = encryptionMethod.getKeySize();
                algorithmId = encryptionMethod.getJweStandardName();
                keyAlgorithm = encryptionMethod.getEncryptionAlgorithm();
                break;
            }
            case ECDH_ES_A128KW: {
                keySize = 128;
                keyAlgorithm = "AES";
                break;
            }
            case ECDH_ES_A192KW: {
                keySize = 192;
                keyAlgorithm = "AES";
                break;
            }
            case ECDH_ES_A256KW: {
                keySize = 256;
                keyAlgorithm = "AES";
                break;
            }
            default: {
                throw new IllegalArgumentException("Invalid ECDH encryption algorithm: " + algorithm);
            }
        }
        return new ECDHEncryptionHandler(keyWrappingHandler, keyAlgorithm, keySize, algorithmId);
    }

    /**
     * Returns an {@link ECDHDerivedKey} that wraps the content encryption key produced by the
     * wrapped handler. The ECDH agreement itself happens later, in
     * {@link #generateJWEEncryptedKey}.
     */
    @Override
    public Key getContentEncryptionKey() {
        return new ECDHDerivedKey((SecretKey) this.keyWrappingHandler.getContentEncryptionKey(), this.keyAlgorithm, this.keySize);
    }

    /**
     * Performs the sender side of the ECDH agreement and wraps the content encryption key.
     * <p>
     * Side effects on {@code header}: sets {@code epk} to our ephemeral public key. Fills in
     * {@code apu}/{@code apv} with SHA-256 hashes of our and their encoded public keys when they
     * are absent. If the header already carries an ephemeral JWK, that key pair is used instead of
     * a freshly generated one.
     *
     * @param key the recipient's EC public key.
     * @param ephemeralKey the {@link ECDHDerivedKey} returned by {@link #getContentEncryptionKey()}.
     * @param header the JWE header, updated in place.
     * @return the value produced by the wrapped handler for the JWE Encrypted Key.
     * @throws JweEncryptionException if the recipient key is invalid or on a different curve to the
     *         header's ephemeral key.
     */
    @Override
    public byte[] generateJWEEncryptedKey(Key key, Key ephemeralKey, JweHeader header) {
        ECPublicKey theirPublicKey = (ECPublicKey) key;
        SupportedEllipticCurve curve = SupportedEllipticCurve.forKey(theirPublicKey);
        if (!EcJWK.isPublicKeyValid(theirPublicKey, curve)) {
            throw new JweEncryptionException("Invalid public key");
        }
        ECDHDerivedKey keyAgreement = (ECDHDerivedKey) ephemeralKey;
        keyAgreement.setTheirPublicKey(theirPublicKey);
        EcJWK jwk = (EcJWK) header.getEphemeralPublicKey();
        if (jwk != null) {
            if (jwk.getEllipticCurve() != curve) {
                throw new JweEncryptionException("Ephemeral key is on different curve to public key");
            }
            keyAgreement.setOurKeyPair(jwk.toKeyPair());
        }
        jwk = new EcJWK(keyAgreement.getOurPublicKey(), null, null);
        header.setEphemeralPublicKey(jwk);
        if (header.getAgreementPartyUInfo() == null) {
            header.setAgreementPartyUInfo(Base64url.encode(Utils.sha256(keyAgreement.getOurPublicKey().getEncoded())));
        }
        if (header.getAgreementPartyVInfo() == null) {
            header.setAgreementPartyVInfo(Base64url.encode(Utils.sha256(theirPublicKey.getEncoded())));
        }
        // OtherInfo must be computed after apu/apv are final, as both feed into the KDF.
        keyAgreement.setOtherInfo(ECDHEncryptionHandler.generateOtherInfo(header, this.algorithmId, this.keySize));
        DestroyableSecretKey derivedKey = keyAgreement.getDerivedKey();
        return this.keyWrappingHandler.generateJWEEncryptedKey(derivedKey, keyAgreement.getContentEncryptionKey(), header);
    }

    /**
     * Builds the Concat KDF {@code OtherInfo} structure defined in RFC 7518 section 4.6.2.
     * <p>
     * The fields are concatenated as big-endian, 32-bit-length-prefixed values:
     * <pre>
     *   AlgorithmID  = len || ASCII(algorithmId)
     *   PartyUInfo   = len || base64url-decode(apu)   (empty when absent)
     *   PartyVInfo   = len || base64url-decode(apv)   (empty when absent)
     *   SuppPubInfo  = keySize (bits), 32-bit big-endian
     *   SuppPrivInfo = empty
     * </pre>
     *
     * @param header the JWE header supplying {@code apu}/{@code apv}.
     * @param algorithmId the KDF AlgorithmID.
     * @param keySize the derived key length in bits.
     * @return the encoded OtherInfo bytes.
     */
    static byte[] generateOtherInfo(JweHeader header, String algorithmId, int keySize) {
        byte[] algorithmIdBytes = algorithmId.getBytes(StandardCharsets.US_ASCII);
        byte[] partyUInfo = ECDHEncryptionHandler.base64urlDecode(header.getAgreementPartyUInfo());
        byte[] partyVInfo = ECDHEncryptionHandler.base64urlDecode(header.getAgreementPartyVInfo());
        ByteBuffer buffer = ByteBuffer.allocate(
                4 + algorithmIdBytes.length
                + 4 + partyUInfo.length
                + 4 + partyVInfo.length
                + 4).order(ByteOrder.BIG_ENDIAN);
        buffer.putInt(algorithmIdBytes.length).put(algorithmIdBytes);
        buffer.putInt(partyUInfo.length).put(partyUInfo);
        buffer.putInt(partyVInfo.length).put(partyVInfo);
        buffer.putInt(keySize);
        return buffer.array();
    }

    /** Base64url-decodes {@code data}, treating {@code null} or empty input as zero bytes. */
    private static byte[] base64urlDecode(String data) {
        if (data == null || data.isEmpty()) {
            return new byte[0];
        }
        return Base64url.decode(data);
    }

    @Override
    public byte[] generateInitialisationVector() {
        return this.keyWrappingHandler.generateInitialisationVector();
    }

    /**
     * Encrypts the plaintext with the wrapped handler. It uses the wrapped content encryption key
     * when present, otherwise the ECDH-derived key. The {@link ECDHDerivedKey} (and all key
     * material it holds) is destroyed afterwards.
     */
    @Override
    public JweEncryption encryptPlaintext(Key contentEncryptionKey, byte[] initialisationVector, byte[] plaintext, byte[] additionalAuthenticatedData) {
        try (ECDHDerivedKey agreedKey = (ECDHDerivedKey) contentEncryptionKey) {
            Key encryptionKey = agreedKey.getContentEncryptionKey();
            if (encryptionKey == null) {
                encryptionKey = agreedKey.getDerivedKey();
            }
            return this.keyWrappingHandler.encryptPlaintext(encryptionKey, initialisationVector, plaintext, additionalAuthenticatedData);
        }
    }

    /**
     * Performs the recipient side of the ECDH agreement and unwraps the content encryption key.
     * <p>
     * The header's ephemeral key must be an EC JWK with a valid public point. If it declares an
     * algorithm, that algorithm must be of type ECDH-ES. If it declares a use, that use must be
     * {@code enc}. The key must also be on the same curve as {@code key}. The derived key is
     * destroyed afterwards unless the wrapped handler returned it as the content key itself.
     *
     * @param key our EC private key.
     * @param encryptedContentEncryptionKey the JWE Encrypted Key.
     * @param header the JWE header carrying {@code epk}, {@code apu} and {@code apv}.
     * @return the content encryption key.
     * @throws JweDecryptionException if the ephemeral key is missing, not an EC key, invalid,
     *         or otherwise unacceptable.
     */
    @Override
    public Key decryptContentEncryptionKey(Key key, byte[] encryptedContentEncryptionKey, JweHeader header) {
        ECPrivateKey ourPrivateKey = (ECPrivateKey) key;
        Object ephemeralKey = header.getEphemeralPublicKey();
        if (!(ephemeralKey instanceof EcJWK)) {
            // Covers both a missing epk header and a non-EC JWK, which previously caused NPE/CCE.
            logger.debug("ECDH: Cannot derive content key: missing or non-EC ephemeral public key");
            throw new JweDecryptionException();
        }
        EcJWK jwk = (EcJWK) ephemeralKey;
        if (!jwk.isPublicKeyValid()) {
            throw new JweDecryptionException();
        }
        Object jwaAlgorithm = jwk.getJwaAlgorithm();
        if (jwaAlgorithm instanceof JweAlgorithm
                && ((JweAlgorithm) jwaAlgorithm).getAlgorithmType() != JweAlgorithmType.ECDH_ES) {
            throw new JweDecryptionException();
        }
        if (jwk.getUse() != null && jwk.getUse() != KeyUse.ENC) {
            throw new JweDecryptionException();
        }
        ECPublicKey theirPublicKey = jwk.toECPublicKey();
        if (SupportedEllipticCurve.forKey(theirPublicKey) != SupportedEllipticCurve.forKey(ourPrivateKey)) {
            logger.debug("ECDH: Cannot derive content key: ephemeral key on different curve to our key");
            throw new JweDecryptionException();
        }
        ECDHDerivedKey keyAgreement = new ECDHDerivedKey(ourPrivateKey, this.keyAlgorithm, this.keySize);
        keyAgreement.setTheirPublicKey(theirPublicKey);
        keyAgreement.setOtherInfo(ECDHEncryptionHandler.generateOtherInfo(header, this.algorithmId, this.keySize));
        DestroyableSecretKey derivedKey = keyAgreement.getDerivedKey();
        Key decryptedKey = null;
        try {
            decryptedKey = this.keyWrappingHandler.decryptContentEncryptionKey(derivedKey, encryptedContentEncryptionKey, header);
            return decryptedKey;
        } finally {
            // If the wrapped handler returned the derived key itself, the caller still needs it.
            if (decryptedKey != derivedKey) {
                derivedKey.destroy();
            }
        }
    }

    /**
     * Decrypts the ciphertext with the wrapped handler, then destroys {@code contentEncryptionKey}
     * (when it is {@link Destroyable}), whether or not decryption succeeds.
     */
    @Override
    public byte[] decryptCiphertext(Key contentEncryptionKey, byte[] initialisationVector, byte[] ciphertext, byte[] authenticationTag, byte[] additionalAuthenticatedData) {
        try {
            return this.keyWrappingHandler.decryptCiphertext(contentEncryptionKey, initialisationVector, ciphertext, authenticationTag, additionalAuthenticatedData);
        } finally {
            ECDHEncryptionHandler.destroyOrLog(contentEncryptionKey, "content decryption key");
        }
    }

    /**
     * Destroys {@code key} if it implements {@link Destroyable}. A {@link DestroyFailedException}
     * is logged at debug level rather than propagated. {@code null} keys are ignored.
     */
    private static void destroyOrLog(Key key, String description) {
        if (key instanceof Destroyable) {
            try {
                ((Destroyable) ((Object) key)).destroy();
            } catch (DestroyFailedException e) {
                logger.debug("Unable to destroy key material for {}: {}", (Object) description, (Object) e.getMessage());
            }
        }
    }

    /**
     * A RAW-format {@link SecretKey} whose bytes can be zeroed via {@link #destroy()}.
     * {@link #getEncoded()} returns the internal array (not a copy), and returns {@code null}
     * once the key has been destroyed.
     */
    @VisibleForTesting
    static class DestroyableSecretKey
    implements SecretKey,
    Destroyable {
        private static final long serialVersionUID = 1L;
        private volatile byte[] keyBytes;
        private final String algorithm;

        DestroyableSecretKey(byte[] bytes, String algorithm) {
            this.keyBytes = bytes;
            this.algorithm = algorithm;
        }

        @Override
        public String getAlgorithm() {
            return this.algorithm;
        }

        @Override
        public String getFormat() {
            return "RAW";
        }

        @Override
        public byte[] getEncoded() {
            return this.keyBytes;
        }

        /** Zeroes the key bytes and drops the reference. Repeated calls are no-ops. */
        @Override
        public void destroy() {
            byte[] tmp = this.keyBytes;
            if (tmp != null) {
                Arrays.fill(tmp, (byte) 0);
                this.keyBytes = null;
            }
        }

        @Override
        public boolean isDestroyed() {
            return this.keyBytes == null;
        }
    }

    /**
     * Holds the state of one ECDH agreement: our key pair, their public key, the KDF OtherInfo,
     * the lazily derived key and (sender side only) the wrapped content encryption key.
     * <p>
     * When built from a private key (recipient side), {@code contentEncryptionKey} is {@code null}.
     * {@link #getAlgorithm()}, {@link #getFormat()} and {@link #getEncoded()} would then throw
     * {@link NullPointerException}.
     */
    @VisibleForTesting
    static class ECDHDerivedKey
    implements Key,
    Destroyable,
    AutoCloseable {
        private static final long serialVersionUID = 1L;
        private final String keyAlgorithm;
        private final int keySize;
        private final Key contentEncryptionKey;
        private KeyPair ourKeyPair;
        private ECPublicKey theirPublicKey;
        private DestroyableSecretKey derivedKey;
        private byte[] otherInfo;

        ECDHDerivedKey(SecretKey contentEncryptionKey, String keyAlgorithm, int keySize) {
            this.contentEncryptionKey = contentEncryptionKey;
            this.keyAlgorithm = keyAlgorithm;
            this.keySize = keySize;
        }

        /**
         * Recipient-side constructor. The key pair has no public half, which also marks the private
         * key as non-ephemeral so {@link #destroyEphemeralKey()} will not destroy it.
         */
        ECDHDerivedKey(ECPrivateKey privateKey, String keyAlgorithm, int keySize) {
            this((SecretKey) null, keyAlgorithm, keySize);
            this.ourKeyPair = new KeyPair(null, privateKey);
        }

        Key getContentEncryptionKey() {
            return this.contentEncryptionKey;
        }

        /** Sets the peer's public key and invalidates any previously derived key. */
        void setTheirPublicKey(ECPublicKey publicKey) {
            this.theirPublicKey = publicKey;
            this.derivedKey = null;
        }

        /** Returns whether every byte is zero, without early exit on the first non-zero byte. */
        private static boolean allZero(byte[] data) {
            byte result = 0;
            for (byte b : data) {
                result = (byte) (result | b);
            }
            return result == 0;
        }

        void setOtherInfo(byte[] otherInfo) {
            this.otherInfo = otherInfo;
        }

        /** Returns our key pair, generating an ephemeral one on their public key's curve if absent. */
        private KeyPair getOurKeyPair() {
            if (this.ourKeyPair == null) {
                try {
                    SupportedEllipticCurve curve = SupportedEllipticCurve.forKey(this.theirPublicKey);
                    KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(ECDHEncryptionHandler.KEY_GENERATION_ALGORITHM);
                    keyPairGenerator.initialize(curve.getParameters());
                    this.ourKeyPair = keyPairGenerator.generateKeyPair();
                } catch (GeneralSecurityException e) {
                    throw new JweEncryptionException(e);
                }
            }
            return this.ourKeyPair;
        }

        void setOurKeyPair(KeyPair keyPair) {
            this.ourKeyPair = keyPair;
        }

        ECPublicKey getOurPublicKey() {
            return (ECPublicKey) this.getOurKeyPair().getPublic();
        }

        /**
         * Returns the derived key, computing it on first use. The computation runs ECDH between our
         * private key and their public key, rejects an all-zero shared secret, and then applies the
         * Concat KDF.
         * <p>
         * The raw shared secret is always zeroed before returning, including on every exception path.
         *
         * @return the derived key (cached until {@link #setTheirPublicKey} or {@link #destroy()}).
         * @throws JweEncryptionException if ECDH yields an all-zero shared secret.
         * @throws JweException if the key agreement cannot be initialised.
         */
        DestroyableSecretKey getDerivedKey() {
            if (this.derivedKey != null) {
                return this.derivedKey;
            }
            byte[] sharedSecret = null;
            try {
                KeyAgreement keyAgreement = KeyAgreement.getInstance(ECDHEncryptionHandler.KEY_AGREEMENT_ALGORITHM);
                keyAgreement.init(this.getOurKeyPair().getPrivate());
                keyAgreement.doPhase(this.theirPublicKey, true);
                sharedSecret = keyAgreement.generateSecret();
                if (ECDHDerivedKey.allZero(sharedSecret)) {
                    throw new JweEncryptionException("ECDH produced all-zero shared secret");
                }
                this.derivedKey = ECDHDerivedKey.concatKdf(sharedSecret, this.keyAlgorithm, this.keySize, this.otherInfo);
                return this.derivedKey;
            } catch (InvalidKeyException | NoSuchAlgorithmException e) {
                throw new JweException(e);
            } finally {
                if (sharedSecret != null) {
                    Arrays.fill(sharedSecret, (byte) 0);
                }
            }
        }

        /**
         * Single-step Concat KDF with SHA-256. It computes
         * {@code H(counter || Z || OtherInfo)} for counter = 1..ceil(keySize / 256), and truncates
         * the concatenation to {@code (keySize + 7) / 8} bytes.
         * <p>
         * The temporary input buffer (which contains a copy of the shared secret) and each
         * intermediate digest are zeroed after use.
         *
         * @param sharedSecret the ECDH shared secret Z.
         * @param keyAlgorithm algorithm name for the resulting key.
         * @param keySize derived key length in bits; must be within [128, 512].
         * @param otherInfo the OtherInfo bytes; must not be {@code null}.
         * @return the derived key.
         */
        private static DestroyableSecretKey concatKdf(byte[] sharedSecret, String keyAlgorithm, int keySize, byte[] otherInfo) {
            Reject.ifFalse(keySize >= 128 && keySize <= 512);
            MessageDigest messageDigest;
            try {
                messageDigest = MessageDigest.getInstance(ECDHEncryptionHandler.HASH_ALGORITHM);
            } catch (NoSuchAlgorithmException e) {
                throw new IllegalStateException(e);
            }
            int repetitions = (int) Math.ceil(keySize / ECDHEncryptionHandler.HASH_LENGTH);
            byte[] input = new byte[4 + sharedSecret.length + otherInfo.length];
            ByteBuffer buffer = ByteBuffer.wrap(input).order(ByteOrder.BIG_ENDIAN);
            byte[] keyBytes = new byte[(keySize + 7) / 8];
            try {
                int offset = 0;
                for (int counter = 1; counter <= repetitions; ++counter) {
                    buffer.clear();
                    buffer.putInt(counter).put(sharedSecret).put(otherInfo);
                    byte[] hash = messageDigest.digest(input);
                    int length = Math.min(hash.length, keyBytes.length - offset);
                    System.arraycopy(hash, 0, keyBytes, offset, length);
                    offset += length;
                    Arrays.fill(hash, (byte) 0);
                }
            } finally {
                // The input holds a copy of the shared secret; never leave it in memory.
                Arrays.fill(input, (byte) 0);
            }
            return new DestroyableSecretKey(keyBytes, keyAlgorithm);
        }

        @Override
        public String getAlgorithm() {
            return this.contentEncryptionKey.getAlgorithm();
        }

        @Override
        public String getFormat() {
            return this.contentEncryptionKey.getFormat();
        }

        @Override
        public byte[] getEncoded() {
            return this.contentEncryptionKey.getEncoded();
        }

        /**
         * Destroys the ephemeral private key (if any), the derived key and the wrapped content
         * encryption key. Failures are logged, not thrown.
         */
        @Override
        public void destroy() {
            this.destroyEphemeralKey();
            ECDHEncryptionHandler.destroyOrLog(this.derivedKey, "derived key");
            this.derivedKey = null;
            ECDHEncryptionHandler.destroyOrLog(this.contentEncryptionKey, "content encryption key");
        }

        /**
         * Destroys our private key only when the key pair has a public half, i.e. it is ephemeral.
         * The recipient's static private key (public half {@code null}) is left untouched.
         */
        void destroyEphemeralKey() {
            if (this.ourKeyPair != null && this.ourKeyPair.getPublic() != null) {
                ECDHEncryptionHandler.destroyOrLog(this.ourKeyPair.getPrivate(), "ephemeral private key");
                this.ourKeyPair = null;
            }
        }

        @Override
        public boolean isDestroyed() {
            return this.ourKeyPair == null || this.ourKeyPair.getPrivate().isDestroyed();
        }

        @Override
        public void close() {
            this.destroy();
        }
    }
}

