| /* |
| * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. |
| * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. |
| * |
| * This code is free software; you can redistribute it and/or modify it |
| * under the terms of the GNU General Public License version 2 only, as |
| * published by the Free Software Foundation. |
| * |
| * This code is distributed in the hope that it will be useful, but WITHOUT |
| * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
| * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License |
| * version 2 for more details (a copy is included in the LICENSE file that |
| * accompanied this code). |
| * |
| * You should have received a copy of the GNU General Public License version |
| * 2 along with this work; if not, write to the Free Software Foundation, |
| * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. |
| * |
| * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA |
| * or visit www.oracle.com if you need additional information or have any |
| * questions. |
| */ |
| |
| import java.io.Closeable; |
| import java.io.IOException; |
| import java.io.ObjectInputStream; |
| import java.io.ObjectOutputStream; |
| import java.io.Serializable; |
| import java.net.ServerSocket; |
| import java.net.Socket; |
| import java.net.UnknownHostException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.util.StringJoiner; |
| import javax.security.auth.callback.Callback; |
| import javax.security.auth.callback.CallbackHandler; |
| import javax.security.auth.callback.NameCallback; |
| import javax.security.auth.callback.PasswordCallback; |
| import javax.security.auth.callback.UnsupportedCallbackException; |
| import javax.security.sasl.AuthorizeCallback; |
| import javax.security.sasl.RealmCallback; |
| import javax.security.sasl.RealmChoiceCallback; |
| import javax.security.sasl.Sasl; |
| import javax.security.sasl.SaslClient; |
| import javax.security.sasl.SaslException; |
| import javax.security.sasl.SaslServer; |
| |
| /* |
| * @test |
| * @bug 8049814 |
| * @summary Java SASL server and client tests with the CRAM-MD5 and |
| * DIGEST-MD5 mechanisms. The tests try different QOP values on the |
| * client and server sides. |
| * @modules java.security.sasl/javax.security.sasl |
| */ |
/**
 * Runs SASL client/server handshakes over a loopback socket for the CRAM-MD5
 * and DIGEST-MD5 mechanisms, pairing different client and server QOP lists
 * and checking that each negotiation succeeds or fails as expected.
 */
public class ClientServerTest {

    private static final String LOCALHOST = "localhost";
    private static final String DIGEST_MD5 = "DIGEST-MD5";
    private static final String CRAM_MD5 = "CRAM-MD5";
    private static final String PROTOCOL = "saslservice";
    private static final String USER_ID = "sasltester";
    private static final String PASSWD = "password";
    private static final String QOP_AUTH = "auth";
    private static final String QOP_AUTH_CONF = "auth-conf";
    private static final String QOP_AUTH_INT = "auth-int";
    private static final String AUTHID_SASL_TESTER = "sasl_tester";

    /**
     * Mechanisms advertised by the server. Kept as a java.util.ArrayList
     * because it is sent through Java serialization and the client checks
     * for an ArrayList instance.
     */
    private static final ArrayList<String> SUPPORT_MECHS =
            new ArrayList<>(Arrays.asList(DIGEST_MD5, CRAM_MD5));

    public static void main(String[] args) throws Exception {
        String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
        String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
        String[] authQop = { QOP_AUTH };
        String[] authIntQop = { QOP_AUTH_INT };
        String[] authConfQop = { QOP_AUTH_CONF };
        String[] emptyQop = {};

        boolean success = true;

        // "auth" on both sides: expected to succeed.
        success &= runTest("", CRAM_MD5, authQop, authQop, false);
        success &= runTest("", DIGEST_MD5, authQop, authQop, false);
        success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5,
                authQop, authQop, false);

        // The server's QOP is in the client's list: expected to succeed.
        success &= runTest("", DIGEST_MD5, allQops, authQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);

        // No QOP in common, or none configured on the server: expected to fail.
        success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);

        if (!success) {
            throw new RuntimeException("At least one test case failed");
        }

        System.out.println("Test passed");
    }

    /**
     * Runs one handshake against a fresh single-connection server.
     *
     * @param authId          authorization id returned by the callback handler
     *                        (blank means "use the user id")
     * @param mech            SASL mechanism the client requests
     * @param clientQops      QOP values configured on the client
     * @param serverQops      QOP values configured on the server
     * @param expectException whether the client is expected to fail with a
     *                        {@link SaslException}
     * @return true when the outcome matched the expectation
     */
    private static boolean runTest(String authId, String mech,
            String[] clientQops, String[] serverQops, boolean expectException)
            throws Exception {

        System.out.println("AuthId:" + authId
                + " mechanism:" + mech
                + " clientQops: " + Arrays.toString(clientQops)
                + " serverQops: " + Arrays.toString(serverQops)
                + " expect exception:" + expectException);

        try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
            new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
                    .run();
            if (expectException) {
                System.out.println("Expected exception not thrown");
                return false;
            }
        } catch (SaslException e) {
            if (!expectException) {
                System.out.println("Unexpected exception: " + e);
                return false;
            }
            System.out.println("Expected exception: " + e);
        }

        return true;
    }

    /** Status tag carried by every handshake message. */
    static enum SaslStatus {
        SUCCESS, FAILURE, CONTINUE
    }

    /** A serialized handshake frame: a status plus optional SASL token bytes. */
    static class Message implements Serializable {

        private static final long serialVersionUID = 1L;

        private final SaslStatus status;
        private final byte[] data;

        public Message(SaslStatus status, byte[] data) {
            this.status = status;
            this.data = data;
        }

        public SaslStatus getStatus() {
            return status;
        }

        public byte[] getData() {
            return data;
        }
    }

    /** Configuration shared by the client and server ends of a handshake. */
    static class SaslPeer {

        final String host;
        final String mechanism;
        // Comma-separated QOP list, passed as the Sasl.QOP property.
        final String qop;
        final CallbackHandler callback;

        SaslPeer(String host, String authId, String... qops) {
            this(host, null, authId, qops);
        }

        SaslPeer(String host, String mechanism, String authId, String... qops) {
            this.host = host;
            this.mechanism = mechanism;

            StringJoiner sj = new StringJoiner(",");
            for (String q : qops) {
                sj.add(q);
            }
            qop = sj.toString();

            callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
        }

        /** Casts a received object to {@link Message}, failing on anything else. */
        Message getMessage(Object ob) {
            if (!(ob instanceof Message)) {
                throw new RuntimeException("Expected an instance of Message");
            }
            return (Message) ob;
        }
    }

    /** Single-connection SASL server running on a daemon thread. */
    static class Server extends SaslPeer implements Runnable, Closeable {

        // Bound eagerly in the constructor: the port is valid as soon as the
        // object exists, and a bind failure propagates to the caller instead
        // of leaving start() polling forever for a readiness flag.
        private final ServerSocket ssocket;

        /**
         * Creates a server bound to an ephemeral port and starts serving one
         * connection on a daemon thread.
         */
        static Server start(String host, String authId, String[] serverQops)
                throws IOException {
            Server server = new Server(host, authId, serverQops);
            Thread thread = new Thread(server);
            thread.setDaemon(true);
            thread.start();
            return server;
        }

        Server(String host, String authId, String... qops) throws IOException {
            super(host, authId, qops);
            ssocket = new ServerSocket(0);
            System.out.println("server started on port "
                    + ssocket.getLocalPort());
        }

        int getPort() {
            return ssocket.getLocalPort();
        }

        /** Sends the mechanism list, then drives the SASL exchange to completion. */
        private void processConnection(SaslEndpoint endpoint)
                throws SaslException, IOException, ClassNotFoundException {
            System.out.println("process connection");
            endpoint.send(SUPPORT_MECHS);
            Object o = endpoint.receive();
            if (!(o instanceof String)) {
                throw new RuntimeException("Received unexpected object: " + o);
            }
            String mech = (String) o;
            SaslServer saslServer = createSaslServer(mech);
            Message msg = getMessage(endpoint.receive());
            while (!saslServer.isComplete()) {
                byte[] data = processData(msg.getData(), endpoint,
                        saslServer);
                if (saslServer.isComplete()) {
                    System.out.println("server is complete");
                    endpoint.send(new Message(SaslStatus.SUCCESS, data));
                } else {
                    System.out.println("server continues");
                    endpoint.send(new Message(SaslStatus.CONTINUE, data));
                    msg = getMessage(endpoint.receive());
                }
            }
        }

        /** Evaluates a client response, notifying the client with FAILURE on error. */
        private byte[] processData(byte[] data, SaslEndpoint endpoint,
                SaslServer server) throws SaslException, IOException {
            try {
                return server.evaluateResponse(data);
            } catch (SaslException e) {
                endpoint.send(new Message(SaslStatus.FAILURE, null));
                System.out.println("Error while processing data");
                throw e;
            }
        }

        private SaslServer createSaslServer(String mechanism)
                throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            return Sasl.createSaslServer(mechanism, PROTOCOL, host, props,
                    callback);
        }

        @Override
        public void run() {
            try (ServerSocket ss = ssocket;
                    SaslEndpoint endpoint = new SaslEndpoint(ss.accept())) {
                System.out.println("server accepted connection");
                processConnection(endpoint);
            } catch (Exception e) {
                // Not fatal here: the client observes the failure (closed
                // connection or FAILURE status) and reports it. Log it so
                // server-side problems are not invisible.
                System.out.println("server error: " + e);
            }
        }

        @Override
        public void close() throws IOException {
            if (!ssocket.isClosed()) {
                ssocket.close();
            }
        }
    }

    /** SASL client that connects to a {@link Server} and runs one handshake. */
    static class Client extends SaslPeer {

        private final int port;

        Client(String host, int port, String mech, String authId,
                String... qops) {
            super(host, mech, authId, qops);
            this.port = port;
        }

        /**
         * Connects, negotiates the mechanism, and runs the challenge/response
         * loop.
         *
         * @throws SaslException if negotiation fails on either side
         */
        public void run() throws Exception {
            System.out.println("Host:" + host + " port: "
                    + port);
            try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
                negotiateMechanism(endpoint);
                SaslClient client = createSaslClient();
                byte[] data = new byte[0];
                if (client.hasInitialResponse()) {
                    data = client.evaluateChallenge(data);
                }
                endpoint.send(new Message(SaslStatus.CONTINUE, data));
                Message msg = getMessage(endpoint.receive());
                while (!client.isComplete()
                        && msg.getStatus() != SaslStatus.FAILURE) {
                    switch (msg.getStatus()) {
                        case CONTINUE:
                            System.out.println("client continues");
                            data = client.evaluateChallenge(msg.getData());
                            endpoint.send(new Message(SaslStatus.CONTINUE,
                                    data));
                            msg = getMessage(endpoint.receive());
                            break;
                        case SUCCESS:
                            System.out.println("client succeeded");
                            data = client.evaluateChallenge(msg.getData());
                            if (data != null) {
                                throw new SaslException("data should be null");
                            }
                            // msg is not refreshed in this branch, so an
                            // incomplete client would re-evaluate the same
                            // final token forever; fail instead.
                            if (!client.isComplete()) {
                                throw new SaslException(
                                        "server completed but client did not");
                            }
                            break;
                        default:
                            throw new RuntimeException("Wrong status:"
                                    + msg.getStatus());
                    }
                }

                if (msg.getStatus() == SaslStatus.FAILURE) {
                    // The server's SaslServer rejected the exchange: report it
                    // as a SASL failure so runTest() classifies it.
                    throw new SaslException("Status is FAILURE");
                }
            }

            System.out.println("Done");
        }

        private SaslClient createSaslClient() throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            return Sasl.createSaslClient(new String[] {mechanism}, USER_ID,
                    PROTOCOL, host, props, callback);
        }

        /** Checks that the server offers our mechanism, then selects it. */
        private void negotiateMechanism(SaslEndpoint endpoint)
                throws ClassNotFoundException, IOException {
            Object o = endpoint.receive();
            if (!(o instanceof ArrayList)) {
                throw new RuntimeException(
                        "Expected an instance of ArrayList, but received " + o);
            }
            ArrayList<?> list = (ArrayList<?>) o;
            if (!list.contains(mechanism)) {
                throw new RuntimeException(
                        "Server does not support specified mechanism:"
                        + mechanism);
            }

            endpoint.send(mechanism);
        }

    }

    /**
     * Object-stream wrapper around a socket. The streams are created lazily
     * on first use, so each side constructs its ObjectInputStream only when
     * it first reads; the peer's stream header has already been sent by then.
     */
    static class SaslEndpoint implements AutoCloseable {

        private final Socket socket;
        private ObjectInputStream input;
        private ObjectOutputStream output;

        static SaslEndpoint create(String host, int port) throws IOException {
            return new SaslEndpoint(new Socket(host, port));
        }

        SaslEndpoint(Socket socket) throws IOException {
            this.socket = socket;
        }

        private ObjectInputStream getInput() throws IOException {
            if (input == null && socket != null) {
                input = new ObjectInputStream(socket.getInputStream());
            }
            return input;
        }

        private ObjectOutputStream getOutput() throws IOException {
            if (output == null && socket != null) {
                output = new ObjectOutputStream(socket.getOutputStream());
            }
            return output;
        }

        public Object receive() throws IOException, ClassNotFoundException {
            return getInput().readObject();
        }

        /** Writes one object and flushes so the peer can read it immediately. */
        public void send(Object obj) throws IOException {
            getOutput().writeObject(obj);
            getOutput().flush();
        }

        @Override
        public void close() throws IOException {
            if (socket != null && !socket.isClosed()) {
                socket.close();
            }
        }

    }

    /** Callback handler that answers with fixed test credentials. */
    static class TestCallbackHandler implements CallbackHandler {

        private final String userId;
        private final char[] passwd;
        private final String realm;
        // Authorized id reported via AuthorizeCallback; a null or blank
        // authId falls back to userId.
        private final String authId;

        TestCallbackHandler(String userId, String passwd, String realm,
                String authId) {
            this.userId = userId;
            this.passwd = passwd.toCharArray();
            this.realm = realm;
            this.authId = (authId == null || authId.trim().length() == 0)
                    ? userId : authId;
        }

        @Override
        public void handle(Callback[] callbacks) throws IOException,
                UnsupportedCallbackException {
            for (Callback callback : callbacks) {
                if (callback instanceof NameCallback) {
                    System.out.println("NameCallback");
                    ((NameCallback) callback).setName(userId);
                } else if (callback instanceof PasswordCallback) {
                    System.out.println("PasswordCallback");
                    ((PasswordCallback) callback).setPassword(passwd);
                } else if (callback instanceof RealmCallback) {
                    System.out.println("RealmCallback");
                    ((RealmCallback) callback).setText(realm);
                } else if (callback instanceof RealmChoiceCallback) {
                    System.out.println("RealmChoiceCallback");
                    RealmChoiceCallback choice = (RealmChoiceCallback) callback;
                    if (realm == null) {
                        choice.setSelectedIndex(choice.getDefaultChoice());
                    } else {
                        // Select the entry equal to our realm; if none
                        // matches, nothing is selected.
                        String[] choices = choice.getChoices();
                        for (int j = 0; j < choices.length; j++) {
                            if (realm.equals(choices[j])) {
                                choice.setSelectedIndex(j);
                                break;
                            }
                        }
                    }
                } else if (callback instanceof AuthorizeCallback) {
                    System.out.println("AuthorizeCallback");
                    ((AuthorizeCallback) callback).setAuthorized(true);
                    ((AuthorizeCallback) callback).setAuthorizedID(authId);
                } else {
                    throw new UnsupportedCallbackException(callback);
                }
            }
        }
    }

}